!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief Routines to calculate RI-GPW-MP2 energy using pw
!> \par History
!>      06.2012 created [Mauro Del Ben]
!>      03.2019 Refactored from mp2_ri_gpw [Frederick Stein]
! **************************************************************************************************
MODULE mp2_ri_gpw
   USE cp_log_handling,                 ONLY: cp_to_string
   USE dgemm_counter_types,             ONLY: dgemm_counter_init,&
                                              dgemm_counter_start,&
                                              dgemm_counter_stop,&
                                              dgemm_counter_type,&
                                              dgemm_counter_write
   USE group_dist_types,                ONLY: get_group_dist,&
                                              group_dist_d1_type,&
                                              maxsize,&
                                              release_group_dist
   USE kinds,                           ONLY: dp,&
                                              int_8
   USE libint_2c_3c,                    ONLY: compare_potential_types
   USE local_gemm_api,                  ONLY: LOCAL_GEMM_PU_GPU
   USE machine,                         ONLY: m_flush,&
                                              m_memory,&
                                              m_walltime
   USE message_passing,                 ONLY: mp_comm_type,&
                                              mp_para_env_type
   USE mp2_ri_grad_util,                ONLY: complete_gamma
   USE mp2_types,                       ONLY: mp2_type,&
                                              three_dim_real_array

!$ USE OMP_LIB, ONLY: omp_get_max_threads, omp_get_thread_num
#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'mp2_ri_gpw'

   PUBLIC :: mp2_ri_gpw_compute_en

CONTAINS

! **************************************************************************************************
!> \brief Computes RI-MP2 energy contributions (Coulomb-type and, if calc_ex, exchange-type) from BIb_C
!> \param Emp2_Cou ...
!> \param Emp2_EX ...
!> \param Emp2_S ...
!> \param Emp2_T ...
!> \param BIb_C ...
!> \param mp2_env ...
!> \param para_env ...
!> \param para_env_sub ...
!> \param color_sub ...
!> \param gd_array ...
!> \param gd_B_virtual ...
!> \param Eigenval ...
!> \param nmo ...
!> \param homo ...
!> \param dimen_RI ...
!> \param unit_nr ...
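!> \param calc_forces if true, also accumulate the intermediates (P_ab, Gamma_P_ia) needed for gradients
!> \param calc_ex if true, also evaluate the exchange-type MP2 energy contribution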
! **************************************************************************************************
   SUBROUTINE mp2_ri_gpw_compute_en(Emp2_Cou, Emp2_EX, Emp2_S, Emp2_T, BIb_C, mp2_env, para_env, para_env_sub, color_sub, &
                                    gd_array, gd_B_virtual, &
                                    Eigenval, nmo, homo, dimen_RI, unit_nr, calc_forces, calc_ex)
      REAL(KIND=dp), INTENT(INOUT)                       :: Emp2_Cou, Emp2_EX, Emp2_S, Emp2_T
      TYPE(three_dim_real_array), DIMENSION(:), &
         INTENT(INOUT)                                   :: BIb_C
      TYPE(mp2_type)                                     :: mp2_env
      TYPE(mp_para_env_type), INTENT(IN), POINTER        :: para_env, para_env_sub
      INTEGER, INTENT(IN)                                :: color_sub
      TYPE(group_dist_d1_type), INTENT(INOUT)            :: gd_array
      INTEGER, DIMENSION(:), INTENT(IN)                  :: homo
      INTEGER, INTENT(IN)                                :: nmo
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: Eigenval
      TYPE(group_dist_d1_type), DIMENSION(SIZE(homo)), &
         INTENT(INOUT)                                   :: gd_B_virtual
      INTEGER, INTENT(IN)                                :: dimen_RI, unit_nr
      LOGICAL, INTENT(IN)                                :: calc_forces, calc_ex

      CHARACTER(LEN=*), PARAMETER :: routineN = 'mp2_ri_gpw_compute_en'

      INTEGER :: a, a_global, b, b_global, block_size, decil, end_point, handle, handle2, handle3, &
         iiB, ij_counter, ij_counter_send, ij_index, integ_group_size, ispin, jjB, jspin, &
         max_ij_pairs, my_block_size, my_group_L_end, my_group_L_size, my_group_L_size_orig, &
         my_group_L_start, my_i, my_ij_pairs, my_j, my_new_group_L_size, ngroup, nspins, &
         num_integ_group, proc_receive, proc_send, proc_shift, rec_B_size, rec_B_virtual_end, &
         rec_B_virtual_start, rec_L_size, send_B_size, send_B_virtual_end, send_B_virtual_start, &
         send_i, send_ij_index, send_j, start_point, tag, total_ij_pairs
      INTEGER, ALLOCATABLE, DIMENSION(:) :: integ_group_pos2color_sub, my_B_size, &
         my_B_virtual_end, my_B_virtual_start, num_ij_pairs, sizes_array_orig, virtual
      INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: ij_map
      INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: ranges_info_array
      LOGICAL                                            :: my_alpha_beta_case, my_beta_beta_case, &
                                                            my_open_shell_SS
      REAL(KIND=dp)                                      :: amp_fac, my_Emp2_Cou, my_Emp2_EX, &
                                                            sym_fac, t_new, t_start
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), TARGET   :: buffer_1D
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
         TARGET                                          :: local_ab, local_ba, t_ab
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
         TARGET                                          :: local_i_aL, local_j_aL, Y_i_aP, Y_j_aP
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :), &
         POINTER                                         :: external_ab, external_i_aL
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :, :), &
         POINTER                                         :: BI_C_rec
      TYPE(dgemm_counter_type)                           :: dgemm_counter
      TYPE(mp_comm_type)                                 :: comm_exchange, comm_rep
      TYPE(three_dim_real_array), ALLOCATABLE, &
         DIMENSION(:)                                    :: B_ia_Q

      CALL timeset(routineN, handle)

      nspins = SIZE(homo)

      ALLOCATE (virtual(nspins))
      virtual(:) = nmo - homo(:)

      ALLOCATE (my_B_size(nspins), my_B_virtual_start(nspins), my_B_virtual_end(nspins))
      DO ispin = 1, nspins
         CALL get_group_dist(gd_B_virtual(ispin), para_env_sub%mepos, &
                             my_B_virtual_start(ispin), my_B_virtual_end(ispin), my_B_size(ispin))
      END DO

      CALL get_group_dist(gd_array, color_sub, my_group_L_start, my_group_L_end, my_group_L_size)

      CALL dgemm_counter_init(dgemm_counter, unit_nr, mp2_env%ri_mp2%print_dgemm_info)

      ! local_gemm_ctx has a very large footprint the first time this routine is
      ! called.
      CALL mp2_env%local_gemm_ctx%create(LOCAL_GEMM_PU_GPU)
      CALL mp2_env%local_gemm_ctx%set_op_threshold_gpu(128*128*128*2)

      CALL mp2_ri_get_integ_group_size( &
         mp2_env, para_env, para_env_sub, gd_array, gd_B_virtual, &
         homo, dimen_RI, unit_nr, &
         integ_group_size, ngroup, &
         num_integ_group, virtual, calc_forces)

      ! now create a group that contains all the proc that have the same virtual starting point
      ! in the integ group
      CALL mp2_ri_create_group( &
         para_env, para_env_sub, color_sub, &
         gd_array%sizes, calc_forces, &
         integ_group_size, my_group_L_end, &
         my_group_L_size, my_group_L_size_orig, my_group_L_start, my_new_group_L_size, &
         integ_group_pos2color_sub, sizes_array_orig, &
         ranges_info_array, comm_exchange, comm_rep, num_integ_group)

      ! We cannot fix the tag because of the recv routine
      tag = 42

      DO jspin = 1, nspins

         CALL replicate_iaK_2intgroup(BIb_C(jspin)%array, comm_exchange, comm_rep, &
                                      homo(jspin), gd_array%sizes, my_B_size(jspin), &
                                      my_group_L_size, ranges_info_array)

         DO ispin = 1, jspin

            IF (unit_nr > 0) THEN
               IF (nspins == 1) THEN
                  WRITE (unit_nr, *) "Start loop run"
               ELSE IF (ispin == 1 .AND. jspin == 1) THEN
                  WRITE (unit_nr, *) "Start loop run alpha-alpha"
               ELSE IF (ispin == 1 .AND. jspin == 2) THEN
                  WRITE (unit_nr, *) "Start loop run alpha-beta"
               ELSE IF (ispin == 2 .AND. jspin == 2) THEN
                  WRITE (unit_nr, *) "Start loop run beta-beta"
               END IF
               CALL m_flush(unit_nr)
            END IF

            my_open_shell_SS = (nspins == 2) .AND. (ispin == jspin)

            ! t_ab = amp_fac*(:,a|:,b)-(:,b|:,a)
            ! If we calculate the gradient we need to distinguish
            ! between alpha-alpha and beta-beta cases for UMP2

            my_beta_beta_case = .FALSE.
            my_alpha_beta_case = .FALSE.
            IF (ispin /= jspin) THEN
               my_alpha_beta_case = .TRUE.
            ELSE IF (my_open_shell_SS) THEN
               IF (ispin == 2) my_beta_beta_case = .TRUE.
            END IF

            amp_fac = mp2_env%scale_S + mp2_env%scale_T
            IF (my_alpha_beta_case .OR. my_open_shell_SS) amp_fac = mp2_env%scale_T

            CALL mp2_ri_allocate_no_blk(local_ab, t_ab, mp2_env, homo, virtual, my_B_size, &
                                        my_group_L_size, calc_forces, ispin, jspin, local_ba)

            CALL mp2_ri_get_block_size( &
               mp2_env, para_env, para_env_sub, gd_array, gd_B_virtual(ispin:jspin), &
               homo(ispin:jspin), virtual(ispin:jspin), dimen_RI, unit_nr, block_size, &
               ngroup, num_integ_group, my_open_shell_ss, calc_forces, buffer_1D)

            ! *****************************************************************
            ! **********  REPLICATION-BLOCKED COMMUNICATION SCHEME  ***********
            ! *****************************************************************
            ! introduce block size, the number of occupied orbitals has to be a
            ! multiple of the block size

            ! Calculate the maximum number of ij pairs that have to be computed
            ! among groups
            CALL mp2_ri_communication(my_alpha_beta_case, total_ij_pairs, homo(ispin), homo(jspin), &
                                      block_size, ngroup, ij_map, color_sub, my_ij_pairs, my_open_shell_SS, unit_nr)

            ALLOCATE (num_ij_pairs(0:comm_exchange%num_pe - 1))
            CALL comm_exchange%allgather(my_ij_pairs, num_ij_pairs)

            max_ij_pairs = MAXVAL(num_ij_pairs)

            ! start real stuff
            CALL mp2_ri_allocate_blk(dimen_RI, my_B_size, block_size, local_i_aL, &
                                     local_j_aL, calc_forces, Y_i_aP, Y_j_aP, ispin, jspin)

            CALL timeset(routineN//"_RI_loop", handle2)
            my_Emp2_Cou = 0.0_dp
            my_Emp2_EX = 0.0_dp
            t_start = m_walltime()
            DO ij_index = 1, max_ij_pairs

               ! Prediction is unreliable if we are in the first step of the loop
               IF (unit_nr > 0 .AND. ij_index > 1) THEN
                  decil = ij_index*10/max_ij_pairs
                  IF (decil /= (ij_index - 1)*10/max_ij_pairs) THEN
                     t_new = m_walltime()
                     t_new = (t_new - t_start)/60.0_dp*(max_ij_pairs - ij_index + 1)/(ij_index - 1)
                     WRITE (unit_nr, FMT="(T3,A)") "Percentage of finished loop: "// &
                        cp_to_string(decil*10)//". Minutes left: "//cp_to_string(t_new)
                     CALL m_flush(unit_nr)
                  END IF
               END IF

               IF (calc_forces) THEN
                  Y_i_aP = 0.0_dp
                  Y_j_aP = 0.0_dp
               END IF

               IF (ij_index <= my_ij_pairs) THEN
                  ! We have work to do
                  ij_counter = (ij_index - MIN(1, color_sub))*ngroup + color_sub
                  my_i = ij_map(1, ij_counter)
                  my_j = ij_map(2, ij_counter)
                  my_block_size = ij_map(3, ij_counter)

                  local_i_aL = 0.0_dp
                  CALL fill_local_i_aL(local_i_aL(:, :, 1:my_block_size), ranges_info_array(:, :, comm_exchange%mepos), &
                                       BIb_C(ispin)%array(:, :, my_i:my_i + my_block_size - 1))

                  local_j_aL = 0.0_dp
                  CALL fill_local_i_aL(local_j_aL(:, :, 1:my_block_size), ranges_info_array(:, :, comm_exchange%mepos), &
                                       BIb_C(jspin)%array(:, :, my_j:my_j + my_block_size - 1))

                  ! collect data from other proc
                  CALL timeset(routineN//"_comm", handle3)
                  DO proc_shift = 1, comm_exchange%num_pe - 1
                     proc_send = MODULO(comm_exchange%mepos + proc_shift, comm_exchange%num_pe)
                     proc_receive = MODULO(comm_exchange%mepos - proc_shift, comm_exchange%num_pe)

                     send_ij_index = num_ij_pairs(proc_send)

                     CALL get_group_dist(gd_array, proc_receive, sizes=rec_L_size)

                     IF (ij_index <= send_ij_index) THEN
                        ij_counter_send = (ij_index - MIN(1, integ_group_pos2color_sub(proc_send)))*ngroup + &
                                          integ_group_pos2color_sub(proc_send)
                        send_i = ij_map(1, ij_counter_send)
                        send_j = ij_map(2, ij_counter_send)

                        ! occupied i
                        BI_C_rec(1:rec_L_size, 1:my_B_size(ispin), 1:my_block_size) => &
                           buffer_1D(1:rec_L_size*my_B_size(ispin)*my_block_size)
                        BI_C_rec = 0.0_dp
                        CALL comm_exchange%sendrecv(BIb_C(ispin)%array(:, :, send_i:send_i + my_block_size - 1), &
                                                    proc_send, BI_C_rec, proc_receive, tag)

                        CALL fill_local_i_aL(local_i_aL(:, :, 1:my_block_size), ranges_info_array(:, :, proc_receive), &
                                             BI_C_rec(:, 1:my_B_size(ispin), :))

                        ! occupied j
                        BI_C_rec(1:rec_L_size, 1:my_B_size(jspin), 1:my_block_size) => &
                           buffer_1D(1:INT(rec_L_size, int_8)*my_B_size(jspin)*my_block_size)
                        BI_C_rec = 0.0_dp
                        CALL comm_exchange%sendrecv(BIb_C(jspin)%array(:, :, send_j:send_j + my_block_size - 1), &
                                                    proc_send, BI_C_rec, proc_receive, tag)

                        CALL fill_local_i_aL(local_j_aL(:, :, 1:my_block_size), ranges_info_array(:, :, proc_receive), &
                                             BI_C_rec(:, 1:my_B_size(jspin), :))

                     ELSE
                        ! we send nothing while we know that we have to receive something

                        ! occupied i
                        BI_C_rec(1:rec_L_size, 1:my_B_size(ispin), 1:my_block_size) => &
                           buffer_1D(1:INT(rec_L_size, int_8)*my_B_size(ispin)*my_block_size)
                        BI_C_rec = 0.0_dp
                        CALL comm_exchange%recv(BI_C_rec, proc_receive, tag)

                        CALL fill_local_i_aL(local_i_aL(:, :, 1:my_block_size), ranges_info_array(:, :, proc_receive), &
                                             BI_C_rec(:, 1:my_B_size(ispin), 1:my_block_size))

                        ! occupied j
                        BI_C_rec(1:rec_L_size, 1:my_B_size(jspin), 1:my_block_size) => &
                           buffer_1D(1:INT(rec_L_size, int_8)*my_B_size(jspin)*my_block_size)
                        BI_C_rec = 0.0_dp
                        CALL comm_exchange%recv(BI_C_rec, proc_receive, tag)

                        CALL fill_local_i_aL(local_j_aL(:, :, 1:my_block_size), ranges_info_array(:, :, proc_receive), &
                                             BI_C_rec(:, 1:my_B_size(jspin), 1:my_block_size))

                     END IF

                  END DO

                  CALL timestop(handle3)

                  ! loop over the block elements
                  DO iiB = 1, my_block_size
                     DO jjB = 1, my_block_size
                        CALL timeset(routineN//"_expansion", handle3)
                        ASSOCIATE (my_local_i_aL => local_i_aL(:, :, iiB), my_local_j_aL => local_j_aL(:, :, jjB))

                           ! calculate the integrals (ia|jb) starting from my local data ...
                           local_ab = 0.0_dp
                           IF ((my_alpha_beta_case) .AND. (calc_forces)) THEN
                              local_ba = 0.0_dp
                           END IF
                           CALL dgemm_counter_start(dgemm_counter)
                           CALL mp2_env%local_gemm_ctx%gemm('T', 'N', my_B_size(ispin), my_B_size(jspin), dimen_RI, 1.0_dp, &
                                                            my_local_i_aL, dimen_RI, my_local_j_aL, dimen_RI, &
                                                           0.0_dp, local_ab(my_B_virtual_start(ispin):my_B_virtual_end(ispin), :), &
                                                            my_B_size(ispin))
                           ! Additional integrals only for alpha_beta case and forces
                           IF (my_alpha_beta_case .AND. calc_forces) THEN
                              local_ba(my_B_virtual_start(jspin):my_B_virtual_end(jspin), :) = &
                                 TRANSPOSE(local_ab(my_B_virtual_start(ispin):my_B_virtual_end(ispin), :))
                           END IF
                           ! ... and from the other of my subgroup
                           DO proc_shift = 1, para_env_sub%num_pe - 1
                              proc_send = MODULO(para_env_sub%mepos + proc_shift, para_env_sub%num_pe)
                              proc_receive = MODULO(para_env_sub%mepos - proc_shift, para_env_sub%num_pe)

                              CALL get_group_dist(gd_B_virtual(ispin), proc_receive, rec_B_virtual_start, &
                                                  rec_B_virtual_end, rec_B_size)

                              external_i_aL(1:dimen_RI, 1:rec_B_size) => buffer_1D(1:INT(dimen_RI, int_8)*rec_B_size)
                              external_i_aL = 0.0_dp

                              CALL para_env_sub%sendrecv(my_local_i_aL, proc_send, &
                                                         external_i_aL, proc_receive, tag)

                              CALL mp2_env%local_gemm_ctx%gemm( &
                                 'T', 'N', rec_B_size, my_B_size(jspin), dimen_RI, 1.0_dp, &
                                 external_i_aL, dimen_RI, my_local_j_aL, dimen_RI, &
                                 0.0_dp, local_ab(rec_B_virtual_start:rec_B_virtual_end, 1:my_B_size(jspin)), rec_B_size)

                              ! Additional integrals only for alpha_beta case and forces
                              IF (my_alpha_beta_case .AND. calc_forces) THEN

                                 CALL get_group_dist(gd_B_virtual(jspin), proc_receive, rec_B_virtual_start, &
                                                     rec_B_virtual_end, rec_B_size)

                                 external_i_aL(1:dimen_RI, 1:rec_B_size) => buffer_1D(1:INT(dimen_RI, int_8)*rec_B_size)
                                 external_i_aL = 0.0_dp

                                 CALL para_env_sub%sendrecv(my_local_j_aL, proc_send, &
                                                            external_i_aL, proc_receive, tag)

                                 CALL mp2_env%local_gemm_ctx%gemm('T', 'N', rec_B_size, my_B_size(ispin), dimen_RI, 1.0_dp, &
                                                                  external_i_aL, dimen_RI, my_local_i_aL, dimen_RI, &
                                            0.0_dp, local_ba(rec_B_virtual_start:rec_B_virtual_end, 1:my_B_size(ispin)), rec_B_size)
                              END IF
                           END DO
                           IF (my_alpha_beta_case .AND. calc_forces) THEN
                              ! This is only an approximation: the exact count would be
                              ! (virtual_i*B_size_j + virtual_j*B_size_i)*dimen_RI, but the call takes a single (m, n, k) triple.
                              CALL dgemm_counter_stop(dgemm_counter, virtual(ispin), my_B_size(ispin) + my_B_size(jspin), dimen_RI)
                           ELSE
                              CALL dgemm_counter_stop(dgemm_counter, virtual(ispin), my_B_size(jspin), dimen_RI)
                           END IF
                           CALL timestop(handle3)

                           !sample peak memory
                           CALL m_memory()

                           CALL timeset(routineN//"_ener", handle3)
                           ! calculate coulomb only MP2
                           sym_fac = 2.0_dp
                           IF (my_i == my_j) sym_fac = 1.0_dp
                           IF (my_alpha_beta_case) sym_fac = 0.5_dp
                           DO b = 1, my_B_size(jspin)
                              b_global = b + my_B_virtual_start(jspin) - 1
                              DO a = 1, virtual(ispin)
                                 my_Emp2_Cou = my_Emp2_Cou - sym_fac*2.0_dp*local_ab(a, b)**2/ &
                                               (Eigenval(homo(ispin) + a, ispin) + Eigenval(homo(jspin) + b_global, jspin) - &
                                                Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, jspin))
                              END DO
                           END DO

                           IF (calc_ex) THEN
                              ! contract integrals with orbital energies for exchange MP2 energy
                              ! starting with local ...
                              IF (calc_forces .AND. (.NOT. my_alpha_beta_case)) t_ab = 0.0_dp
                              DO b = 1, my_B_size(ispin)
                                 b_global = b + my_B_virtual_start(ispin) - 1
                                 DO a = 1, my_B_size(ispin)
                                    a_global = a + my_B_virtual_start(ispin) - 1
                                    my_Emp2_Ex = my_Emp2_Ex + sym_fac*local_ab(a_global, b)*local_ab(b_global, a)/ &
                                              (Eigenval(homo(ispin) + a_global, ispin) + Eigenval(homo(ispin) + b_global, ispin) - &
                                                  Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, ispin))
                                    IF (calc_forces .AND. (.NOT. my_alpha_beta_case)) THEN
                                     t_ab(a_global, b) = -(amp_fac*local_ab(a_global, b) - mp2_env%scale_T*local_ab(b_global, a))/ &
                                                           (Eigenval(homo(ispin) + a_global, ispin) + &
                                                            Eigenval(homo(ispin) + b_global, ispin) - &
                                                            Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, ispin))
                                    END IF
                                 END DO
                              END DO
                              ! ... and then with external data
                              DO proc_shift = 1, para_env_sub%num_pe - 1
                                 proc_send = MODULO(para_env_sub%mepos + proc_shift, para_env_sub%num_pe)
                                 proc_receive = MODULO(para_env_sub%mepos - proc_shift, para_env_sub%num_pe)

                                 CALL get_group_dist(gd_B_virtual(ispin), proc_receive, &
                                                     rec_B_virtual_start, rec_B_virtual_end, rec_B_size)
                                 CALL get_group_dist(gd_B_virtual(ispin), proc_send, &
                                                     send_B_virtual_start, send_B_virtual_end, send_B_size)

                                 external_ab(1:my_B_size(ispin), 1:rec_B_size) => &
                                    buffer_1D(1:INT(rec_B_size, int_8)*my_B_size(ispin))
                                 external_ab = 0.0_dp

                      CALL para_env_sub%sendrecv(local_ab(send_B_virtual_start:send_B_virtual_end, 1:my_B_size(ispin)), proc_send, &
                                                            external_ab(1:my_B_size(ispin), 1:rec_B_size), proc_receive, tag)

                                 DO b = 1, my_B_size(ispin)
                                    b_global = b + my_B_virtual_start(ispin) - 1
                                    DO a = 1, rec_B_size
                                       a_global = a + rec_B_virtual_start - 1
                                       my_Emp2_Ex = my_Emp2_Ex + sym_fac*local_ab(a_global, b)*external_ab(b, a)/ &
                                              (Eigenval(homo(ispin) + a_global, ispin) + Eigenval(homo(ispin) + b_global, ispin) - &
                                                     Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, ispin))
                                       IF (calc_forces .AND. (.NOT. my_alpha_beta_case)) &
                                         t_ab(a_global, b) = -(amp_fac*local_ab(a_global, b) - mp2_env%scale_T*external_ab(b, a))/ &
                                                              (Eigenval(homo(ispin) + a_global, ispin) + &
                                                               Eigenval(homo(ispin) + b_global, ispin) - &
                                                               Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, ispin))
                                    END DO
                                 END DO
                              END DO
                           END IF
                           CALL timestop(handle3)

                           IF (calc_forces) THEN
                              ! update P_ab, Gamma_P_ia
                              CALL mp2_update_P_gamma(mp2_env, para_env_sub, gd_B_virtual, &
                                                      Eigenval, homo, dimen_RI, iiB, jjB, my_B_size, &
                                                      my_B_virtual_end, my_B_virtual_start, my_i, my_j, virtual, &
                                                      local_ab, t_ab, my_local_i_aL, my_local_j_aL, &
                                                      my_open_shell_ss, Y_i_aP(:, :, iiB), Y_j_aP(:, :, jjB), local_ba, &
                                                      ispin, jspin, dgemm_counter, buffer_1D)

                           END IF

                        END ASSOCIATE

                     END DO ! jjB
                  END DO ! iiB

               ELSE
                  ! We need it later in case of gradients
                  my_block_size = 1

                  CALL timeset(routineN//"_comm", handle3)
                  ! No work to do and we know that we have to receive nothing, but send something
                  ! send data to other proc
                  DO proc_shift = 1, comm_exchange%num_pe - 1
                     proc_send = MODULO(comm_exchange%mepos + proc_shift, comm_exchange%num_pe)
                     proc_receive = MODULO(comm_exchange%mepos - proc_shift, comm_exchange%num_pe)

                     send_ij_index = num_ij_pairs(proc_send)

                     IF (ij_index <= send_ij_index) THEN
                        ! something to send
                        ij_counter_send = (ij_index - MIN(1, integ_group_pos2color_sub(proc_send)))*ngroup + &
                                          integ_group_pos2color_sub(proc_send)
                        send_i = ij_map(1, ij_counter_send)
                        send_j = ij_map(2, ij_counter_send)

                        ! occupied i
                        CALL comm_exchange%send(BIb_C(ispin)%array(:, :, send_i:send_i + my_block_size - 1), &
                                                proc_send, tag)
                        ! occupied j
                        CALL comm_exchange%send(BIb_C(jspin)%array(:, :, send_j:send_j + my_block_size - 1), &
                                                proc_send, tag)
                     END IF
                  END DO
                  CALL timestop(handle3)
               END IF

               ! redistribute gamma
               IF (calc_forces) THEN
                  CALL mp2_redistribute_gamma(mp2_env%ri_grad%Gamma_P_ia(ispin)%array, ij_index, my_B_size(ispin), &
                                              my_block_size, my_group_L_size, my_i, my_ij_pairs, ngroup, &
                                              num_integ_group, integ_group_pos2color_sub, num_ij_pairs, &
                                              ij_map, ranges_info_array, Y_i_aP(:, :, 1:my_block_size), comm_exchange, &
                                              gd_array%sizes, 1, buffer_1D)
                  CALL mp2_redistribute_gamma(mp2_env%ri_grad%Gamma_P_ia(jspin)%array, ij_index, my_B_size(jspin), &
                                              my_block_size, my_group_L_size, my_j, my_ij_pairs, ngroup, &
                                              num_integ_group, integ_group_pos2color_sub, num_ij_pairs, &
                                              ij_map, ranges_info_array, Y_j_aP(:, :, 1:my_block_size), comm_exchange, &
                                              gd_array%sizes, 2, buffer_1D)
               END IF

            END DO
            CALL timestop(handle2)

            DEALLOCATE (local_i_aL)
            DEALLOCATE (local_j_aL)
            DEALLOCATE (ij_map)
            DEALLOCATE (num_ij_pairs)
            DEALLOCATE (local_ab)

            IF (calc_forces) THEN
               DEALLOCATE (Y_i_aP)
               DEALLOCATE (Y_j_aP)
               IF (ALLOCATED(t_ab)) THEN
                  DEALLOCATE (t_ab)
               END IF
               DEALLOCATE (local_ba)

               ! here we check if there are almost degenerate ij
               ! pairs and we update P_ij with these contribution.
               ! If all pairs are degenerate with each other this step will scale O(N^6),
               ! if the number of degenerate pairs scales linearly with the system size
               ! this step will scale O(N^5).
               ! Start counting the number of almost degenerate ij pairs according
               ! to eps_canonical
               CALL quasi_degenerate_P_ij( &
                  mp2_env, Eigenval(:, ispin:jspin), homo(ispin:jspin), virtual(ispin:jspin), my_open_shell_ss, &
                  my_beta_beta_case, Bib_C(ispin:jspin), unit_nr, dimen_RI, &
                  my_B_size(ispin:jspin), ngroup, my_group_L_size, &
                  color_sub, ranges_info_array, comm_exchange, para_env_sub, para_env, &
                  my_B_virtual_start(ispin:jspin), my_B_virtual_end(ispin:jspin), gd_array%sizes, gd_B_virtual(ispin:jspin), &
                  integ_group_pos2color_sub, dgemm_counter, buffer_1D)

            END IF

            DEALLOCATE (buffer_1D)

            ! Dereplicate BIb_C and Gamma_P_ia to save memory
            ! These matrices will not be needed in that fashion anymore
            ! B_ia_Q will needed later
            IF (calc_forces .AND. jspin == nspins) THEN
               IF (.NOT. ALLOCATED(B_ia_Q)) ALLOCATE (B_ia_Q(nspins))
               ALLOCATE (B_ia_Q(ispin)%array(homo(ispin), my_B_size(ispin), my_group_L_size_orig))
               B_ia_Q(ispin)%array = 0.0_dp
               DO jjB = 1, homo(ispin)
                  DO iiB = 1, my_B_size(ispin)
                     B_ia_Q(ispin)%array(jjB, iiB, 1:my_group_L_size_orig) = &
                        BIb_C(ispin)%array(1:my_group_L_size_orig, iiB, jjB)
                  END DO
               END DO
               DEALLOCATE (BIb_C(ispin)%array)

               ! sum Gamma and dereplicate
               ALLOCATE (BIb_C(ispin)%array(my_B_size(ispin), homo(ispin), my_group_L_size_orig))
               DO proc_shift = 1, comm_rep%num_pe - 1
                  ! invert order
                  proc_send = MODULO(comm_rep%mepos - proc_shift, comm_rep%num_pe)
                  proc_receive = MODULO(comm_rep%mepos + proc_shift, comm_rep%num_pe)

                  start_point = ranges_info_array(3, proc_shift, comm_exchange%mepos)
                  end_point = ranges_info_array(4, proc_shift, comm_exchange%mepos)

                  CALL comm_rep%sendrecv(mp2_env%ri_grad%Gamma_P_ia(ispin)%array(:, :, start_point:end_point), &
                                         proc_send, BIb_C(ispin)%array, proc_receive, tag)
!$OMP PARALLEL WORKSHARE DEFAULT(NONE) &
!$OMP          SHARED(mp2_env,BIb_C,ispin,homo,my_B_size,my_group_L_size_orig)
                  mp2_env%ri_grad%Gamma_P_ia(ispin)%array(:, :, 1:my_group_L_size_orig) = &
                     mp2_env%ri_grad%Gamma_P_ia(ispin)%array(:, :, 1:my_group_L_size_orig) &
                     + BIb_C(ispin)%array(:, :, :)
!$OMP END PARALLEL WORKSHARE
               END DO

               BIb_C(ispin)%array(:, :, :) = mp2_env%ri_grad%Gamma_P_ia(ispin)%array(:, :, 1:my_group_L_size_orig)
               DEALLOCATE (mp2_env%ri_grad%Gamma_P_ia(ispin)%array)
               CALL MOVE_ALLOC(BIb_C(ispin)%array, mp2_env%ri_grad%Gamma_P_ia(ispin)%array)
            ELSE IF (jspin == nspins) THEN
               DEALLOCATE (BIb_C(ispin)%array)
            END IF

            CALL para_env%sum(my_Emp2_Cou)
            CALL para_env%sum(my_Emp2_Ex)

            IF (my_open_shell_SS .OR. my_alpha_beta_case) THEN
               IF (my_alpha_beta_case) THEN
                  Emp2_S = Emp2_S + my_Emp2_Cou
                  Emp2_Cou = Emp2_Cou + my_Emp2_Cou
               ELSE
                  my_Emp2_Cou = my_Emp2_Cou*0.25_dp
                  my_Emp2_EX = my_Emp2_EX*0.5_dp
                  Emp2_T = Emp2_T + my_Emp2_Cou + my_Emp2_EX
                  Emp2_Cou = Emp2_Cou + my_Emp2_Cou
                  Emp2_EX = Emp2_EX + my_Emp2_EX
               END IF
            ELSE
               Emp2_Cou = Emp2_Cou + my_Emp2_Cou
               Emp2_EX = Emp2_EX + my_Emp2_EX
            END IF
         END DO

      END DO

      DEALLOCATE (integ_group_pos2color_sub)
      DEALLOCATE (ranges_info_array)

      CALL comm_exchange%free()
      CALL comm_rep%free()

      IF (calc_forces) THEN
         ! recover original information (before replication)
         DEALLOCATE (gd_array%sizes)
         iiB = SIZE(sizes_array_orig)
         ALLOCATE (gd_array%sizes(0:iiB - 1))
         gd_array%sizes(:) = sizes_array_orig
         DEALLOCATE (sizes_array_orig)

         ! Remove replication from BIb_C and reorder the matrix
         my_group_L_size = my_group_L_size_orig

         ! B_ia_Q(ispin)%array will be deallocated inside of complete_gamma
         DO ispin = 1, nspins
            CALL complete_gamma(mp2_env, B_ia_Q(ispin)%array, dimen_RI, homo(ispin), &
                                virtual(ispin), para_env, para_env_sub, ngroup, &
                                my_group_L_size, my_group_L_start, my_group_L_end, &
                                my_B_size(ispin), my_B_virtual_start(ispin), &
                                gd_array, gd_B_virtual(ispin), &
                                ispin)
         END DO
         DEALLOCATE (B_ia_Q)

         IF (nspins == 1) mp2_env%ri_grad%P_ab(1)%array(:, :) = mp2_env%ri_grad%P_ab(1)%array(:, :)*2.0_dp
         BLOCK
            TYPE(mp_comm_type) :: comm
            CALL comm%from_split(para_env, para_env_sub%mepos)
            DO ispin = 1, nspins
               ! P_ab is only replicated over all subgroups
               CALL comm%sum(mp2_env%ri_grad%P_ab(ispin)%array)
               ! P_ij is replicated over all processes
               CALL para_env%sum(mp2_env%ri_grad%P_ij(ispin)%array)
            END DO
            CALL comm%free()
         END BLOCK
      END IF

      CALL release_group_dist(gd_array)
      DO ispin = 1, nspins
         IF (ALLOCATED(BIb_C(ispin)%array)) DEALLOCATE (BIb_C(ispin)%array)
         CALL release_group_dist(gd_B_virtual(ispin))
      END DO

      ! We do not need this matrix later, so deallocate it here to safe memory
      IF (calc_forces) DEALLOCATE (mp2_env%ri_grad%PQ_half)
      IF (calc_forces .AND. .NOT. compare_potential_types(mp2_env%ri_metric, mp2_env%potential_parameter)) &
         DEALLOCATE (mp2_env%ri_grad%operator_half)

      CALL dgemm_counter_write(dgemm_counter, para_env)

      ! release memory allocated by local_gemm when run on GPU. local_gemm_ctx is null on cpu only runs
      CALL mp2_env%local_gemm_ctx%destroy()
      CALL timestop(handle)

   END SUBROUTINE mp2_ri_gpw_compute_en

! **************************************************************************************************
!> \brief Copies the replicated slices of BIb_C_rec into local_i_aL (rank-3 version).
!>        For every column rep of ranges_info_array, the first-index range
!>        ranges_info_array(3:4, rep) of BIb_C_rec is copied to the first-index range
!>        ranges_info_array(1:2, rep) of local_i_aL; the remaining indices are copied as a whole.
!> \param local_i_aL destination array, overwritten on the first-index ranges listed in ranges_info_array
!> \param ranges_info_array per-replica ranges: rows 1:2 destination range, rows 3:4 source range
!> \param BIb_C_rec source array
! **************************************************************************************************
   SUBROUTINE fill_local_i_aL(local_i_aL, ranges_info_array, BIb_C_rec)
      REAL(KIND=dp), DIMENSION(:, :, :), INTENT(INOUT)   :: local_i_aL
      INTEGER, DIMENSION(:, :), INTENT(IN)               :: ranges_info_array
      REAL(KIND=dp), DIMENSION(:, :, :), INTENT(IN)      :: BIb_C_rec

      CHARACTER(LEN=*), PARAMETER                        :: routineN = 'fill_local_i_aL'

      INTEGER                                            :: dst_first, dst_last, handle, n_rep, &
                                                            rep, src_first, src_last

      CALL timeset(routineN, handle)

      n_rep = SIZE(ranges_info_array, 2)
      DO rep = 1, n_rep
         ! destination range in local_i_aL, then source range in the receive buffer
         dst_first = ranges_info_array(1, rep)
         dst_last = ranges_info_array(2, rep)
         src_first = ranges_info_array(3, rep)
         src_last = ranges_info_array(4, rep)

!$OMP PARALLEL WORKSHARE DEFAULT(NONE) &
!$OMP          SHARED(BIb_C_rec,local_i_aL,dst_first,dst_last,src_first,src_last)
         local_i_aL(dst_first:dst_last, :, :) = BIb_C_rec(src_first:src_last, :, :)
!$OMP END PARALLEL WORKSHARE
      END DO

      CALL timestop(handle)

   END SUBROUTINE fill_local_i_aL

! **************************************************************************************************
!> \brief Copies the replicated slices of BIb_C_rec into local_i_aL (rank-2 version).
!>        For every column of ranges_info_array, rows 3:4 give the first-index range read from
!>        BIb_C_rec and rows 1:2 the first-index range written in local_i_aL.
!> \param local_i_aL destination array, overwritten on the first-index ranges listed in ranges_info_array
!> \param ranges_info_array per-replica ranges: rows 1:2 destination range, rows 3:4 source range
!> \param BIb_C_rec source array
! **************************************************************************************************
   SUBROUTINE fill_local_i_aL_2D(local_i_aL, ranges_info_array, BIb_C_rec)
      REAL(KIND=dp), DIMENSION(:, :), INTENT(INOUT)      :: local_i_aL
      INTEGER, DIMENSION(:, :), INTENT(IN)               :: ranges_info_array
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: BIb_C_rec

      CHARACTER(LEN=*), PARAMETER :: routineN = 'fill_local_i_aL_2D'

      INTEGER                                            :: handle, i_range, l_hi, l_lo, n_ranges, &
                                                            r_hi, r_lo

      CALL timeset(routineN, handle)

      n_ranges = SIZE(ranges_info_array, 2)
      DO i_range = 1, n_ranges
         l_lo = ranges_info_array(1, i_range)
         l_hi = ranges_info_array(2, i_range)
         r_lo = ranges_info_array(3, i_range)
         r_hi = ranges_info_array(4, i_range)

!$OMP PARALLEL WORKSHARE DEFAULT(NONE) &
!$OMP          SHARED(BIb_C_rec,local_i_aL,l_lo,l_hi,r_lo,r_hi)
         local_i_aL(l_lo:l_hi, :) = BIb_C_rec(r_lo:r_hi, :)
!$OMP END PARALLEL WORKSHARE
      END DO

      CALL timestop(handle)

   END SUBROUTINE fill_local_i_aL_2D

! **************************************************************************************************
!> \brief Replicates the local slice of BIb_C over all processes of comm_rep.
!>        Every process contributes its slice (padded to the largest entry of sizes_array) to an
!>        allgather over comm_rep. BIb_C is then reallocated to (my_group_L_size, my_B_size, homo),
!>        and the contribution received at shift s is stored in the first-index range
!>        ranges_info_array(3:4, s, comm_exchange%mepos).
!> \param BIb_C local slice on entry, replicated slice on exit (reallocated)
!> \param comm_exchange communicator whose rank selects the slice of ranges_info_array
!> \param comm_rep communicator over which the slices are replicated
!> \param homo size of the third dimension of BIb_C
!> \param sizes_array first-dimension sizes; its maximum sizes the communication buffers
!> \param my_B_size size of the second dimension of BIb_C
!> \param my_group_L_size first dimension of BIb_C after replication
!> \param ranges_info_array replication ranges (as built by mp2_ri_create_group)
! **************************************************************************************************
   SUBROUTINE replicate_iaK_2intgroup(BIb_C, comm_exchange, comm_rep, homo, sizes_array, my_B_size, &
                                      my_group_L_size, ranges_info_array)
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
         INTENT(INOUT)                                   :: BIb_C
      TYPE(mp_comm_type), INTENT(IN)                     :: comm_exchange, comm_rep
      INTEGER, INTENT(IN)                                :: homo
      INTEGER, DIMENSION(:), INTENT(IN)                  :: sizes_array
      INTEGER, INTENT(IN)                                :: my_B_size, my_group_L_size
      INTEGER, DIMENSION(:, 0:, 0:), INTENT(IN)          :: ranges_info_array

      CHARACTER(LEN=*), PARAMETER :: routineN = 'replicate_iaK_2intgroup'

      INTEGER                                            :: handle, i_shift, l_first, l_last, &
                                                            max_L_size, n_rep, src_rank
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: send_buf
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :, :)  :: recv_buf

      CALL timeset(routineN, handle)

      n_rep = comm_rep%num_pe

      ! every contribution is padded to the largest size so that a plain allgather can be used
      max_L_size = MAXVAL(sizes_array)

      ALLOCATE (send_buf(max_L_size, my_B_size, homo))
      send_buf(:, :, :) = 0.0_dp
      send_buf(1:SIZE(BIb_C, 1), 1:my_B_size, 1:homo) = BIb_C
      DEALLOCATE (BIb_C)

      ALLOCATE (recv_buf(max_L_size, my_B_size, homo, 0:n_rep - 1))
      recv_buf(:, :, :, :) = 0.0_dp
      CALL comm_rep%allgather(send_buf, recv_buf)
      DEALLOCATE (send_buf)

      ALLOCATE (BIb_C(my_group_L_size, my_B_size, homo))
      BIb_C(:, :, :) = 0.0_dp

      ! the slice of rank (mepos - shift) goes to the range that was assigned to that shift
      DO i_shift = 0, n_rep - 1
         src_rank = MODULO(comm_rep%mepos - i_shift, n_rep)
         l_first = ranges_info_array(3, i_shift, comm_exchange%mepos)
         l_last = ranges_info_array(4, i_shift, comm_exchange%mepos)
         BIb_C(l_first:l_last, 1:my_B_size, 1:homo) = &
            recv_buf(1:l_last - l_first + 1, 1:my_B_size, 1:homo, src_rank)
      END DO

      DEALLOCATE (recv_buf)

      CALL timestop(handle)

   END SUBROUTINE replicate_iaK_2intgroup

! **************************************************************************************************
!> \brief Allocates the non-blocked work arrays of the ij loop for the spin pair (ispin, jspin).
!>        If forces are requested, the P_ij, P_ab and Gamma_P_ia accumulators of spin jspin are
!>        allocated and zeroed as well, unless they already exist.
!> \param local_ab work array (virtual(ispin), my_B_size(jspin)), zeroed
!> \param t_ab work array (virtual(ispin), my_B_size(jspin)); allocated only if calc_forces and
!>        ispin == jspin, not initialized
!> \param mp2_env MP2 environment holding the gradient accumulators
!> \param homo number of occupied orbitals per spin
!> \param virtual number of virtual orbitals per spin
!> \param my_B_size local number of virtual orbitals per spin
!> \param my_group_L_size last dimension of Gamma_P_ia
!> \param calc_forces whether the force-related arrays are needed
!> \param ispin spin of the i index
!> \param jspin spin of the j index
!> \param local_ba allocated only if calc_forces: (virtual(jspin), my_B_size(ispin)) if
!>        ispin /= jspin, otherwise a 1x1 dummy; not initialized
! **************************************************************************************************
   SUBROUTINE mp2_ri_allocate_no_blk(local_ab, t_ab, mp2_env, homo, virtual, my_B_size, &
                                     my_group_L_size, calc_forces, ispin, jspin, local_ba)
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
         INTENT(OUT)                                     :: local_ab, t_ab
      TYPE(mp2_type)                                     :: mp2_env
      INTEGER, INTENT(IN)                                :: homo(2), virtual(2), my_B_size(2), &
                                                            my_group_L_size
      LOGICAL, INTENT(IN)                                :: calc_forces
      INTEGER, INTENT(IN)                                :: ispin, jspin
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
         INTENT(OUT)                                     :: local_ba

      CHARACTER(LEN=*), PARAMETER :: routineN = 'mp2_ri_allocate_no_blk'

      INTEGER                                            :: handle

      CALL timeset(routineN, handle)

      ALLOCATE (local_ab(virtual(ispin), my_B_size(jspin)))
      local_ab = 0.0_dp

      IF (calc_forces) THEN
         ! All accumulators of spin jspin are sized with the jspin dimensions. P_ij(jspin) couples
         ! occupied orbitals of spin jspin, so it must be homo(jspin) x homo(jspin); using
         ! homo(ispin) gives a wrongly shaped matrix if it is first allocated for a pair with
         ! ispin /= jspin and a different number of alpha and beta occupied orbitals.
         IF (.NOT. ALLOCATED(mp2_env%ri_grad%P_ij(jspin)%array)) THEN
            ALLOCATE (mp2_env%ri_grad%P_ij(jspin)%array(homo(jspin), homo(jspin)))
            mp2_env%ri_grad%P_ij(jspin)%array = 0.0_dp
         END IF
         IF (.NOT. ALLOCATED(mp2_env%ri_grad%P_ab(jspin)%array)) THEN
            ALLOCATE (mp2_env%ri_grad%P_ab(jspin)%array(my_B_size(jspin), virtual(jspin)))
            mp2_env%ri_grad%P_ab(jspin)%array = 0.0_dp
         END IF
         IF (.NOT. ALLOCATED(mp2_env%ri_grad%Gamma_P_ia(jspin)%array)) THEN
            ALLOCATE (mp2_env%ri_grad%Gamma_P_ia(jspin)%array(my_B_size(jspin), homo(jspin), my_group_L_size))
            mp2_env%ri_grad%Gamma_P_ia(jspin)%array = 0.0_dp
         END IF

         IF (ispin == jspin) THEN
            ! For non-alpha-beta case we need amplitudes
            ALLOCATE (t_ab(virtual(ispin), my_B_size(jspin)))

            ! That is just a dummy. In that way, we can pass it as array to other routines w/o requirement for allocatable array
            ALLOCATE (local_ba(1, 1))
         ELSE
            ! We need more integrals
            ALLOCATE (local_ba(virtual(jspin), my_B_size(ispin)))
         END IF
      END IF

      CALL timestop(handle)

   END SUBROUTINE mp2_ri_allocate_no_blk

! **************************************************************************************************
!> \brief Allocates and zeroes the per-block work arrays of the blocked ij loop.
!> \param dimen_RI first dimension of local_i_aL/local_j_aL, second of Y_i_aP/Y_j_aP
!> \param my_B_size local number of virtual orbitals per spin
!> \param block_size third dimension of all allocated arrays
!> \param local_i_aL work array (dimen_RI, my_B_size(ispin), block_size), zeroed
!> \param local_j_aL work array (dimen_RI, my_B_size(jspin), block_size), zeroed
!> \param calc_forces if true, Y_i_aP and Y_j_aP are allocated as well
!> \param Y_i_aP force work array (my_B_size(ispin), dimen_RI, block_size), zeroed
!> \param Y_j_aP force work array (my_B_size(jspin), dimen_RI, block_size), zeroed
!> \param ispin spin of the i index
!> \param jspin spin of the j index
! **************************************************************************************************
   SUBROUTINE mp2_ri_allocate_blk(dimen_RI, my_B_size, block_size, &
                                  local_i_aL, local_j_aL, calc_forces, &
                                  Y_i_aP, Y_j_aP, ispin, jspin)
      INTEGER, INTENT(IN)                                :: dimen_RI, my_B_size(2), block_size
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
         INTENT(OUT)                                     :: local_i_aL, local_j_aL
      LOGICAL, INTENT(IN)                                :: calc_forces
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :), &
         INTENT(OUT)                                     :: Y_i_aP, Y_j_aP
      INTEGER, INTENT(IN)                                :: ispin, jspin

      CHARACTER(LEN=*), PARAMETER :: routineN = 'mp2_ri_allocate_blk'

      INTEGER                                            :: handle

      CALL timeset(routineN, handle)

      ! integral blocks for the i and the j orbitals
      ALLOCATE (local_i_aL(dimen_RI, my_B_size(ispin), block_size))
      ALLOCATE (local_j_aL(dimen_RI, my_B_size(jspin), block_size))
      local_i_aL(:, :, :) = 0.0_dp
      local_j_aL(:, :, :) = 0.0_dp

      IF (calc_forces) THEN
         ! Closed-shell, alpha-alpha and beta-beta have ispin == jspin, so both arrays have the
         ! same shape; in the alpha-beta case Y_j_aP is sized with the jspin virtual size.
         ALLOCATE (Y_i_aP(my_B_size(ispin), dimen_RI, block_size))
         ALLOCATE (Y_j_aP(my_B_size(jspin), dimen_RI, block_size))
         Y_i_aP(:, :, :) = 0.0_dp
         Y_j_aP(:, :, :) = 0.0_dp
      END IF

      CALL timestop(handle)

   END SUBROUTINE mp2_ri_allocate_blk

! **************************************************************************************************
!> \brief Builds the list of ij pairs and distributes it over the ngroup subgroups.
!>        As many block_size x block_size blocks of pairs as can be evenly shared among the
!>        groups are listed first; all remaining pairs follow as single entries.
!> \param my_alpha_beta_case true for the alpha-beta case (every i combined with every j)
!> \param total_ij_pairs total number of ij pairs
!> \param homo number of occupied orbitals of the i spin
!> \param homo_beta number of occupied orbitals of the j spin (only used in the alpha-beta case)
!> \param block_size edge length of an ij block
!> \param ngroup number of subgroups
!> \param ij_map (1:3, n): first i, first j and edge length (block_size or 1) of entry n;
!>        entry n belongs to the subgroup with MOD(n, ngroup) == color_sub
!> \param color_sub color of this subgroup
!> \param my_ij_pairs number of ij_map entries belonging to this subgroup
!> \param my_open_shell_SS true for the alpha-alpha/beta-beta case (only pairs with i < j)
!> \param unit_nr output unit; statistics are written if unit_nr > 0
! **************************************************************************************************
   SUBROUTINE mp2_ri_communication(my_alpha_beta_case, total_ij_pairs, homo, homo_beta, &
                                   block_size, ngroup, ij_map, color_sub, my_ij_pairs, my_open_shell_SS, unit_nr)
      LOGICAL, INTENT(IN)                                :: my_alpha_beta_case
      INTEGER, INTENT(OUT)                               :: total_ij_pairs
      INTEGER, INTENT(IN)                                :: homo, homo_beta, block_size, ngroup
      INTEGER, ALLOCATABLE, DIMENSION(:, :), INTENT(OUT) :: ij_map
      INTEGER, INTENT(IN)                                :: color_sub
      INTEGER, INTENT(OUT)                               :: my_ij_pairs
      LOGICAL, INTENT(IN)                                :: my_open_shell_SS
      INTEGER, INTENT(IN)                                :: unit_nr

      CHARACTER(LEN=*), PARAMETER :: routineN = 'mp2_ri_communication'

      INTEGER :: assigned_blocks, first_I_block, first_J_block, handle, iiB, ij_block_counter, &
         ij_counter, jjB, last_i_block, last_J_block, num_block_per_group, num_IJ_blocks, &
         num_IJ_blocks_beta, total_ij_block, total_ij_pairs_blocks
      LOGICAL, ALLOCATABLE, DIMENSION(:, :)              :: ij_marker

      CALL timeset(routineN, handle)

      ! In every case: count the available blocks, keep only as many as can be shared evenly
      ! among the ngroup subgroups (assigned_blocks), mark the pairs they cover in ij_marker,
      ! and list every unmarked pair as a single entry afterwards.
      IF (.NOT. my_open_shell_ss .AND. .NOT. my_alpha_beta_case) THEN
         ! closed shell: all pairs with i <= j
         total_ij_pairs = homo*(1 + homo)/2
         num_IJ_blocks = homo/block_size - 1

         first_I_block = 1
         last_i_block = block_size*(num_IJ_blocks - 1)

         first_J_block = block_size + 1
         last_J_block = block_size*(num_IJ_blocks + 1)

         ij_block_counter = 0
         DO iiB = first_I_block, last_i_block, block_size
            DO jjB = iiB + block_size, last_J_block, block_size
               ij_block_counter = ij_block_counter + 1
            END DO
         END DO

         total_ij_block = ij_block_counter
         num_block_per_group = total_ij_block/ngroup
         assigned_blocks = num_block_per_group*ngroup

         ! one entry per assigned block plus one entry per remaining single pair
         total_ij_pairs_blocks = assigned_blocks + (total_ij_pairs - assigned_blocks*(block_size**2))

         ALLOCATE (ij_marker(homo, homo))
         ij_marker = .TRUE.
         ALLOCATE (ij_map(3, total_ij_pairs_blocks))
         ij_map = 0
         ij_counter = 0
         my_ij_pairs = 0
         DO iiB = first_I_block, last_i_block, block_size
            DO jjB = iiB + block_size, last_J_block, block_size
               IF (ij_counter + 1 > assigned_blocks) EXIT
               ij_counter = ij_counter + 1
               ij_marker(iiB:iiB + block_size - 1, jjB:jjB + block_size - 1) = .FALSE.
               ij_map(1, ij_counter) = iiB
               ij_map(2, ij_counter) = jjB
               ij_map(3, ij_counter) = block_size
               IF (MOD(ij_counter, ngroup) == color_sub) my_ij_pairs = my_ij_pairs + 1
            END DO
         END DO
         DO iiB = 1, homo
            DO jjB = iiB, homo
               IF (ij_marker(iiB, jjB)) THEN
                  ij_counter = ij_counter + 1
                  ij_map(1, ij_counter) = iiB
                  ij_map(2, ij_counter) = jjB
                  ij_map(3, ij_counter) = 1
                  IF (MOD(ij_counter, ngroup) == color_sub) my_ij_pairs = my_ij_pairs + 1
               END IF
            END DO
         END DO
         DEALLOCATE (ij_marker)

      ELSE IF (.NOT. my_alpha_beta_case) THEN
         ! These are the cases alpha/alpha and beta/beta
         ! We do not have to consider the diagonal elements
         total_ij_pairs = homo*(homo - 1)/2
         num_IJ_blocks = (homo - 1)/block_size - 1

         first_I_block = 1
         last_i_block = block_size*(num_IJ_blocks - 1)

         ! We shift the blocks to prevent the calculation of the diagonal elements which always give zero
         first_J_block = block_size + 2
         last_J_block = block_size*(num_IJ_blocks + 1) + 1

         ij_block_counter = 0
         DO iiB = first_I_block, last_i_block, block_size
            DO jjB = iiB + block_size + 1, last_J_block, block_size
               ij_block_counter = ij_block_counter + 1
            END DO
         END DO

         total_ij_block = ij_block_counter
         num_block_per_group = total_ij_block/ngroup
         assigned_blocks = num_block_per_group*ngroup

         total_ij_pairs_blocks = assigned_blocks + (total_ij_pairs - assigned_blocks*(block_size**2))

         ALLOCATE (ij_marker(homo, homo))
         ij_marker = .TRUE.
         ALLOCATE (ij_map(3, total_ij_pairs_blocks))
         ij_map = 0
         ij_counter = 0
         my_ij_pairs = 0
         DO iiB = first_I_block, last_i_block, block_size
            DO jjB = iiB + block_size + 1, last_J_block, block_size
               IF (ij_counter + 1 > assigned_blocks) EXIT
               ij_counter = ij_counter + 1
               ij_marker(iiB:iiB + block_size - 1, jjB:jjB + block_size - 1) = .FALSE.
               ij_map(1, ij_counter) = iiB
               ij_map(2, ij_counter) = jjB
               ij_map(3, ij_counter) = block_size
               IF (MOD(ij_counter, ngroup) == color_sub) my_ij_pairs = my_ij_pairs + 1
            END DO
         END DO
         DO iiB = 1, homo
            DO jjB = iiB + 1, homo
               IF (ij_marker(iiB, jjB)) THEN
                  ij_counter = ij_counter + 1
                  ij_map(1, ij_counter) = iiB
                  ij_map(2, ij_counter) = jjB
                  ij_map(3, ij_counter) = 1
                  IF (MOD(ij_counter, ngroup) == color_sub) my_ij_pairs = my_ij_pairs + 1
               END IF
            END DO
         END DO
         DEALLOCATE (ij_marker)

      ELSE
         ! alpha-beta: every i of the first spin with every j of the second spin
         total_ij_pairs = homo*homo_beta
         num_IJ_blocks = homo/block_size
         num_IJ_blocks_beta = homo_beta/block_size

         first_I_block = 1
         last_i_block = block_size*(num_IJ_blocks - 1)

         first_J_block = 1
         last_J_block = block_size*(num_IJ_blocks_beta - 1)

         ij_block_counter = 0
         DO iiB = first_I_block, last_i_block, block_size
            DO jjB = first_J_block, last_J_block, block_size
               ij_block_counter = ij_block_counter + 1
            END DO
         END DO

         total_ij_block = ij_block_counter
         num_block_per_group = total_ij_block/ngroup
         assigned_blocks = num_block_per_group*ngroup

         total_ij_pairs_blocks = assigned_blocks + (total_ij_pairs - assigned_blocks*(block_size**2))

         ALLOCATE (ij_marker(homo, homo_beta))
         ij_marker = .TRUE.
         ALLOCATE (ij_map(3, total_ij_pairs_blocks))
         ij_map = 0
         ij_counter = 0
         my_ij_pairs = 0
         DO iiB = first_I_block, last_i_block, block_size
            DO jjB = first_J_block, last_J_block, block_size
               IF (ij_counter + 1 > assigned_blocks) EXIT
               ij_counter = ij_counter + 1
               ij_marker(iiB:iiB + block_size - 1, jjB:jjB + block_size - 1) = .FALSE.
               ij_map(1, ij_counter) = iiB
               ij_map(2, ij_counter) = jjB
               ij_map(3, ij_counter) = block_size
               IF (MOD(ij_counter, ngroup) == color_sub) my_ij_pairs = my_ij_pairs + 1
            END DO
         END DO
         DO iiB = 1, homo
            DO jjB = 1, homo_beta
               IF (ij_marker(iiB, jjB)) THEN
                  ij_counter = ij_counter + 1
                  ij_map(1, ij_counter) = iiB
                  ij_map(2, ij_counter) = jjB
                  ij_map(3, ij_counter) = 1
                  IF (MOD(ij_counter, ngroup) == color_sub) my_ij_pairs = my_ij_pairs + 1
               END IF
            END DO
         END DO
         DEALLOCATE (ij_marker)
      END IF

      IF (unit_nr > 0) THEN
         ! Without any ij pair (e.g. same-spin case with homo == 1, where homo*(homo-1)/2 == 0)
         ! the ratio below would be 0/0; report it like the unblocked case instead.
         IF (block_size == 1 .OR. total_ij_pairs == 0) THEN
            WRITE (UNIT=unit_nr, FMT="(T3,A,T66,F15.1)") &
               "RI_INFO| Percentage of ij pairs communicated with block size 1:", 100.0_dp
         ELSE
            WRITE (UNIT=unit_nr, FMT="(T3,A,T66,F15.1)") &
               "RI_INFO| Percentage of ij pairs communicated with block size 1:", &
               100.0_dp*REAL((total_ij_pairs - assigned_blocks*(block_size**2)), KIND=dp)/REAL(total_ij_pairs, KIND=dp)
         END IF
         CALL m_flush(unit_nr)
      END IF

      CALL timestop(handle)

   END SUBROUTINE mp2_ri_communication

! **************************************************************************************************
!> \brief Sets up the communicators of an integration group and the bookkeeping needed to
!>        replicate the L index over the replication communicator.
!> \param para_env global parallel environment (split to obtain comm_exchange and comm_rep)
!> \param para_env_sub parallel environment of this subgroup
!> \param color_sub color of this subgroup
!> \param sizes_array on exit: replicated L sizes of all ranks of comm_exchange, with bounds
!>        0:integ_group_size-1 (reallocated)
!> \param calc_forces if true, the incoming sizes_array is saved in sizes_array_orig
!> \param integ_group_size number of subgroups per integration group
!> \param my_group_L_end last L index of this process before replication
!> \param my_group_L_size on entry the local L size, on exit the replicated L size
!> \param my_group_L_size_orig local L size before replication
!> \param my_group_L_start first L index of this process before replication
!> \param my_new_group_L_size sum of the L sizes over comm_rep
!> \param integ_group_pos2color_sub subgroup color of every rank of comm_exchange
!> \param sizes_array_orig copy of the incoming sizes_array (allocated only if calc_forces)
!> \param ranges_info_array (1:4, shift, rank of comm_exchange): original L start/end and
!>        start/end in the replicated layout of the block coming from comm_rep rank (mepos - shift)
!> \param comm_exchange processes with the same subgroup rank in the same integration group
!> \param comm_rep processes with the same subgroup rank and the same comm_exchange rank
!> \param num_integ_group number of integration groups
! **************************************************************************************************
   SUBROUTINE mp2_ri_create_group(para_env, para_env_sub, color_sub, &
                                  sizes_array, calc_forces, &
                                  integ_group_size, my_group_L_end, &
                                  my_group_L_size, my_group_L_size_orig, my_group_L_start, my_new_group_L_size, &
                                  integ_group_pos2color_sub, &
                                  sizes_array_orig, ranges_info_array, comm_exchange, comm_rep, num_integ_group)
      TYPE(mp_para_env_type), INTENT(IN)                 :: para_env, para_env_sub
      INTEGER, INTENT(IN)                                :: color_sub
      INTEGER, ALLOCATABLE, DIMENSION(:), INTENT(INOUT)  :: sizes_array
      LOGICAL, INTENT(IN)                                :: calc_forces
      INTEGER, INTENT(IN)                                :: integ_group_size, my_group_L_end
      INTEGER, INTENT(INOUT)                             :: my_group_L_size
      INTEGER, INTENT(OUT)                               :: my_group_L_size_orig
      INTEGER, INTENT(IN)                                :: my_group_L_start
      INTEGER, INTENT(INOUT)                             :: my_new_group_L_size
      INTEGER, ALLOCATABLE, DIMENSION(:), INTENT(OUT)    :: integ_group_pos2color_sub, &
                                                            sizes_array_orig
      INTEGER, ALLOCATABLE, DIMENSION(:, :, :), &
         INTENT(OUT)                                     :: ranges_info_array
      TYPE(mp_comm_type), INTENT(OUT)                    :: comm_exchange, comm_rep
      INTEGER, INTENT(IN)                                :: num_integ_group

      CHARACTER(LEN=*), PARAMETER :: routineN = 'mp2_ri_create_group'

      INTEGER                                            :: handle, i_shift, n_exch, n_rep, &
                                                            n_sizes, split_color, src_rank
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: all_L_ends, all_L_sizes, &
                                                            all_L_starts, exch_L_sizes
      INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: local_ranges

      CALL timeset(routineN, handle)

      ! exchange communicator: same subgroup rank, same integration group
      split_color = para_env_sub%mepos*num_integ_group + color_sub/integ_group_size
      CALL comm_exchange%from_split(para_env, split_color)
      n_exch = comm_exchange%num_pe

      ! replication communicator: same subgroup rank, same position inside comm_exchange
      split_color = para_env_sub%mepos*n_exch + comm_exchange%mepos
      CALL comm_rep%from_split(para_env, split_color)
      n_rep = comm_rep%num_pe

      ! L distribution of every member of the replication communicator
      ALLOCATE (all_L_ends(0:n_rep - 1), all_L_starts(0:n_rep - 1), all_L_sizes(0:n_rep - 1))
      CALL comm_rep%allgather(my_group_L_size, all_L_sizes)
      CALL comm_rep%allgather(my_group_L_start, all_L_starts)
      CALL comm_rep%allgather(my_group_L_end, all_L_ends)

      ! shift 0 is the local block, which comes first in the replicated layout;
      ! shift s holds the block of rank (mepos - s), appended after shift s-1
      ALLOCATE (local_ranges(4, 0:n_rep - 1))
      local_ranges(:, 0) = [my_group_L_start, my_group_L_end, 1, my_group_L_size]
      my_new_group_L_size = my_group_L_size
      DO i_shift = 1, n_rep - 1
         src_rank = MODULO(comm_rep%mepos - i_shift, n_rep)
         my_new_group_L_size = my_new_group_L_size + all_L_sizes(src_rank)
         local_ranges(1, i_shift) = all_L_starts(src_rank)
         local_ranges(2, i_shift) = all_L_ends(src_rank)
         local_ranges(3, i_shift) = local_ranges(4, i_shift - 1) + 1
         local_ranges(4, i_shift) = my_new_group_L_size
      END DO

      ALLOCATE (exch_L_sizes(0:n_exch - 1))
      ALLOCATE (ranges_info_array(4, 0:n_rep - 1, 0:n_exch - 1))
      CALL comm_exchange%allgather(my_new_group_L_size, exch_L_sizes)
      CALL comm_exchange%allgather(local_ranges, ranges_info_array)

      DEALLOCATE (all_L_sizes, all_L_starts, all_L_ends)
      DEALLOCATE (local_ranges)

      ALLOCATE (integ_group_pos2color_sub(0:n_exch - 1))
      CALL comm_exchange%allgather(color_sub, integ_group_pos2color_sub)

      IF (calc_forces) THEN
         n_sizes = SIZE(sizes_array)
         ALLOCATE (sizes_array_orig(0:n_sizes - 1))
         sizes_array_orig(:) = sizes_array
      END IF

      my_group_L_size_orig = my_group_L_size
      my_group_L_size = my_new_group_L_size

      ! sizes_array now describes the replicated sizes within the exchange communicator
      DEALLOCATE (sizes_array)
      ALLOCATE (sizes_array(0:integ_group_size - 1))
      sizes_array(:) = exch_L_sizes
      DEALLOCATE (exch_L_sizes)

      CALL timestop(handle)

   END SUBROUTINE mp2_ri_create_group

! **************************************************************************************************
!> \brief Determines how many subgroups form one integration group, i.e. how many subgroups share
!>        one replica of the 3-center integrals, from a memory and a communication model.
!> \param mp2_env MP2 environment; ri_mp2%number_integration_groups and ri_mp2%block_size are read
!> \param para_env global parallel environment
!> \param para_env_sub parallel environment of a single subgroup
!> \param gd_array group distribution whose maximal local size enters the BIb_C memory estimates
!> \param gd_B_virtual per-spin group distribution of the virtual orbitals
!> \param homo number of occupied orbitals per spin
!> \param dimen_RI number of RI basis functions
!> \param unit_nr output unit; information is written only if unit_nr > 0
!> \param integ_group_size number of subgroups per integration group
!> \param ngroup number of subgroups (para_env%num_pe/para_env_sub%num_pe)
!> \param num_integ_group number of integration groups (ngroup/integ_group_size)
!> \param virtual number of virtual orbitals per spin
!> \param calc_forces if .TRUE., the additional arrays needed for gradients enter the memory model
!> \note A user-provided number of integration groups is only honoured if it divides ngroup;
!>       otherwise the integration group size is determined automatically.
! **************************************************************************************************
   SUBROUTINE mp2_ri_get_integ_group_size(mp2_env, para_env, para_env_sub, gd_array, gd_B_virtual, &
                                          homo, dimen_RI, unit_nr, &
                                          integ_group_size, &
                                          ngroup, num_integ_group, &
                                          virtual, calc_forces)
      TYPE(mp2_type)                                     :: mp2_env
      TYPE(mp_para_env_type), INTENT(IN)                 :: para_env, para_env_sub
      TYPE(group_dist_d1_type), INTENT(IN)               :: gd_array
      TYPE(group_dist_d1_type), DIMENSION(:), INTENT(IN) :: gd_B_virtual
      INTEGER, DIMENSION(:), INTENT(IN)                  :: homo
      INTEGER, INTENT(IN)                                :: dimen_RI, unit_nr
      INTEGER, INTENT(OUT)                               :: integ_group_size, ngroup, num_integ_group
      INTEGER, DIMENSION(:), INTENT(IN)                  :: virtual
      LOGICAL, INTENT(IN)                                :: calc_forces

      CHARACTER(LEN=*), PARAMETER :: routineN = 'mp2_ri_get_integ_group_size'

      INTEGER                                            :: block_size, handle, iiB, &
                                                            max_repl_group_size, &
                                                            min_integ_group_size
      INTEGER(KIND=int_8)                                :: mem
      LOGICAL                                            :: calc_group_size
      REAL(KIND=dp)                                      :: factor, mem_base, mem_min, mem_per_blk, &
                                                            mem_per_repl, mem_per_repl_blk, &
                                                            mem_real, sum_ov

      CALL timeset(routineN, handle)

      ngroup = para_env%num_pe/para_env_sub%num_pe

      ! The user-provided number of integration groups is only used if it divides ngroup
      calc_group_size = mp2_env%ri_mp2%number_integration_groups <= 0
      IF (.NOT. calc_group_size) THEN
         IF (MOD(ngroup, mp2_env%ri_mp2%number_integration_groups) /= 0) calc_group_size = .TRUE.
      END IF

      IF (calc_group_size) THEN
         ! Available memory in MiB (rounded up), minimum over all processes
         CALL m_memory(mem)
         mem_real = (mem + 1024*1024 - 1)/(1024*1024)
         CALL para_env%min(mem_real)

         ! Memory model in MiB with replication group size NR and block size NB:
         ! mem_base + mem_per_blk*NB + NR*(mem_per_repl + mem_per_repl_blk*NB)
         mem_base = 0.0_dp
         mem_per_blk = 0.0_dp
         mem_per_repl = 0.0_dp
         mem_per_repl_blk = 0.0_dp

         ! BIB_C_copy
         mem_per_repl = mem_per_repl + MAXVAL(MAX(REAL(homo, KIND=dp)*maxsize(gd_array), REAL(dimen_RI, KIND=dp))* &
                                              maxsize(gd_B_virtual))*8.0_dp/(1024**2)
         ! BIB_C
         mem_per_repl = mem_per_repl + SUM(REAL(homo, KIND=dp)*maxsize(gd_B_virtual))*maxsize(gd_array)*8.0_dp/(1024**2)
         ! BIB_C_rec
         mem_per_repl_blk = mem_per_repl_blk + REAL(MAXVAL(maxsize(gd_B_virtual)), KIND=dp)*maxsize(gd_array)*8.0_dp/(1024**2)
         ! local_i_aL+local_j_aL
         mem_per_blk = mem_per_blk + 2.0_dp*MAXVAL(maxsize(gd_B_virtual))*REAL(dimen_RI, KIND=dp)*8.0_dp/(1024**2)
         ! local_ab
         mem_base = mem_base + MAXVAL(REAL(virtual, KIND=dp)*maxsize(gd_B_virtual))*8.0_dp/(1024**2)
         ! external_ab/external_i_aL
         mem_base = mem_base + REAL(MAX(dimen_RI, MAXVAL(virtual)), KIND=dp)*MAXVAL(maxsize(gd_B_virtual))*8.0_dp/(1024**2)

         IF (calc_forces) THEN
            ! Gamma_P_ia
            mem_per_repl = mem_per_repl + SUM(REAL(homo, KIND=dp)*maxsize(gd_array)* &
                                              maxsize(gd_B_virtual))*8.0_dp/(1024**2)
            ! Y_i_aP+Y_j_aP
            mem_per_blk = mem_per_blk + 2.0_dp*MAXVAL(maxsize(gd_B_virtual))*dimen_RI*8.0_dp/(1024**2)
            ! local_ba/t_ab
            mem_base = mem_base + REAL(MAXVAL(maxsize(gd_B_virtual)), KIND=dp)*MAX(dimen_RI, MAXVAL(virtual))*8.0_dp/(1024**2)
            ! P_ij
            mem_base = mem_base + SUM(REAL(homo, KIND=dp)*homo)*8.0_dp/(1024**2)
            ! P_ab
            mem_base = mem_base + SUM(REAL(virtual, KIND=dp)*maxsize(gd_B_virtual))*8.0_dp/(1024**2)
            ! send_ab/send_i_aL
            mem_base = mem_base + REAL(MAX(dimen_RI, MAXVAL(virtual)), KIND=dp)*MAXVAL(maxsize(gd_B_virtual))*8.0_dp/(1024**2)
         END IF

         ! This is a first guess based on the assumption of optimal block sizes
         block_size = MAX(1, MIN(FLOOR(SQRT(REAL(MINVAL(homo), KIND=dp))), FLOOR(MINVAL(homo)/SQRT(2.0_dp*ngroup))))
         IF (mp2_env%ri_mp2%block_size > 0) block_size = mp2_env%ri_mp2%block_size

         ! Memory needed without replication (NR = 1)
         mem_min = mem_base + mem_per_repl + (mem_per_blk + mem_per_repl_blk)*block_size

         IF (unit_nr > 0) WRITE (unit_nr, '(T3,A,T68,F9.2,A4)') 'RI_INFO| Minimum available memory per MPI process:', &
            mem_real, ' MiB'
         IF (unit_nr > 0) WRITE (unit_nr, '(T3,A,T68,F9.2,A4)') 'RI_INFO| Minimum required memory per MPI process:', &
            mem_min, ' MiB'

         ! We use the following communication model
         ! Comm(replication)+Comm(collection of data for ij pair)+Comm(contraction)
         ! One can show that the costs of the contraction step are independent of the block size and the replication group size
         ! With gradients, the other two steps are carried out twice (Y_i_aP -> Gamma_i_aP, and dereplication)
         ! NL ... number of RI basis functions
         ! NR ... replication group size
         ! NG ... number of sub groups
         ! NB ... Block size
         ! o  ... number of occupied orbitals
         ! Then, we have the communication costs (in multiples of the original BIb_C matrix)
         ! (NR/NG)+(1-(NR/NG))*(o/NB+NB-2)/NG = (NR/NG)*(1-(o/NB+NB-2)/NG)+(o/NB+NB-2)/NG
         ! and with gradients
         ! 2*(NR/NG)+2*(1-(NR/NG))*(o/NB+NB-2)/NG = 2*((NR/NG)*(1-(o/NB+NB-2)/NG)+(o/NB+NB-2)/NG)
         ! so the sign of the prefactor of (NR/NG) is the same in both cases.
         ! We are looking for the minimum of the communication volume,
         ! thus, if the prefactor of (NR/NG) is smaller than zero, use the largest possible replication group size.
         ! If the factor is larger than zero, set the replication group size to 1. (For small systems and a large number of subgroups)
         ! Replication group size = 1 implies that the integration group size equals the number of subgroups

         integ_group_size = ngroup

         ! Multiply everything by homo*virtual to consider differences between spin channels in case of open-shell calculations.
         ! The sums are formed in REAL to avoid default-INTEGER overflow for large systems.
         sum_ov = SUM(REAL(homo, KIND=dp)*virtual)
         factor = sum_ov &
                  - SUM((REAL(MAXVAL(homo), KIND=dp)/block_size + block_size - 2.0_dp)*homo*virtual)/ngroup
         IF (SIZE(homo) == 2) factor = factor - 2.0_dp*PRODUCT(REAL(homo, KIND=dp))/block_size/ngroup*sum_ov

         IF (factor <= 0.0_dp) THEN
            ! Remove the fixed memory and divide by the memory per replication group size.
            ! Clamp to [1, ngroup] in floating point before converting, so that FLOOR cannot overflow.
            max_repl_group_size = FLOOR(MIN(MAX((mem_real - mem_base - mem_per_blk*block_size)/ &
                                                (mem_per_repl + mem_per_repl_blk*block_size), 1.0_dp), &
                                            REAL(ngroup, KIND=dp)))
            ! Convert to an integration group size
            min_integ_group_size = ngroup/max_repl_group_size

            ! Take the smallest divisor of ngroup that is not below min_integ_group_size;
            ! iiB = ngroup always qualifies, so the loop always ends via EXIT
            DO iiB = MAX(MIN(min_integ_group_size, ngroup), 1), ngroup
               IF (MOD(ngroup, iiB) == 0) THEN
                  integ_group_size = iiB
                  EXIT
               END IF
            END DO
         END IF
      ELSE ! We take the user provided group size
         integ_group_size = ngroup/mp2_env%ri_mp2%number_integration_groups
      END IF

      IF (unit_nr > 0) THEN
         WRITE (UNIT=unit_nr, FMT="(T3,A,T75,i6)") &
            "RI_INFO| Group size for integral replication:", integ_group_size*para_env_sub%num_pe
         CALL m_flush(unit_nr)
      END IF

      num_integ_group = ngroup/integ_group_size

      CALL timestop(handle)

   END SUBROUTINE mp2_ri_get_integ_group_size

! **************************************************************************************************
!> \brief Determines the block size for the occupied (ij) pairs from the available memory and
!>        allocates the communication buffer used by the RI-MP2 loops.
!> \param mp2_env MP2 environment; ri_mp2%block_size (user override if > 0) is read
!> \param para_env global parallel environment
!> \param para_env_sub parallel environment of a single subgroup
!> \param gd_array group distribution whose maximal local size enters the receive-buffer size
!> \param gd_B_virtual per-spin group distribution of the virtual orbitals
!> \param homo number of occupied orbitals per spin
!> \param virtual number of virtual orbitals per spin
!> \param dimen_RI number of RI basis functions
!> \param unit_nr output unit; the block size is printed only if unit_nr > 0
!> \param block_size resulting block size (at least 1)
!> \param ngroup number of subgroups (para_env%num_pe/para_env_sub%num_pe)
!> \param num_integ_group number of integration groups
!> \param my_open_shell_ss .TRUE. for the open-shell same-spin case (diagonal pairs are excluded)
!> \param calc_forces if .TRUE., gradient arrays enter the memory estimate and the buffer is doubled
!> \param buffer_1D allocated on exit; receive buffer (followed by the send buffer if calc_forces)
! **************************************************************************************************
   SUBROUTINE mp2_ri_get_block_size(mp2_env, para_env, para_env_sub, gd_array, gd_B_virtual, &
                                    homo, virtual, dimen_RI, unit_nr, &
                                    block_size, ngroup, num_integ_group, &
                                    my_open_shell_ss, calc_forces, buffer_1D)
      TYPE(mp2_type)                                     :: mp2_env
      TYPE(mp_para_env_type), INTENT(IN)                 :: para_env, para_env_sub
      TYPE(group_dist_d1_type), INTENT(IN)               :: gd_array
      TYPE(group_dist_d1_type), DIMENSION(:), INTENT(IN) :: gd_B_virtual
      INTEGER, DIMENSION(:), INTENT(IN)                  :: homo, virtual
      INTEGER, INTENT(IN)                                :: dimen_RI, unit_nr
      INTEGER, INTENT(OUT)                               :: block_size, ngroup
      INTEGER, INTENT(IN)                                :: num_integ_group
      LOGICAL, INTENT(IN)                                :: my_open_shell_ss, calc_forces
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), &
         INTENT(OUT)                                     :: buffer_1D

      CHARACTER(LEN=*), PARAMETER :: routineN = 'mp2_ri_get_block_size'

      INTEGER                                            :: best_block_size, handle, num_IJ_blocks
      INTEGER(KIND=int_8)                                :: buffer_size, mem
      REAL(KIND=dp)                                      :: max_B_size, max_dim, mem_base, &
                                                            mem_per_blk, mem_per_repl_blk, mem_real

      CALL timeset(routineN, handle)

      ngroup = para_env%num_pe/para_env_sub%num_pe

      ! Available memory in MiB (rounded up), minimum over all processes
      CALL m_memory(mem)
      mem_real = (mem + 1024*1024 - 1)/(1024*1024)
      CALL para_env%min(mem_real)

      ! Convert the array extents to REAL before multiplying them: their INTEGER product can
      ! exceed the range of a default INTEGER for large systems.
      max_B_size = REAL(MAXVAL(maxsize(gd_B_virtual)), KIND=dp)
      max_dim = REAL(MAX(dimen_RI, MAXVAL(virtual)), KIND=dp)

      mem_base = 0.0_dp
      mem_per_blk = 0.0_dp
      mem_per_repl_blk = 0.0_dp

      ! external_ab
      mem_base = mem_base + max_B_size*max_dim*8.0_dp/(1024**2)
      ! BIB_C_rec
      mem_per_repl_blk = mem_per_repl_blk + max_B_size*maxsize(gd_array)*8.0_dp/(1024**2)
      ! local_i_aL+local_j_aL
      mem_per_blk = mem_per_blk + 2.0_dp*max_B_size*REAL(dimen_RI, KIND=dp)*8.0_dp/(1024**2)
      ! Copy to keep arrays contiguous
      mem_base = mem_base + max_B_size*max_dim*8.0_dp/(1024**2)

      IF (calc_forces) THEN
         ! Y_i_aP+Y_j_aP+BIb_C_send
         mem_per_blk = mem_per_blk + 3.0_dp*max_B_size*dimen_RI*8.0_dp/(1024**2)
         ! send_ab
         mem_base = mem_base + max_B_size*max_dim*8.0_dp/(1024**2)
      END IF

      IF (mp2_env%ri_mp2%block_size > 0) THEN
         best_block_size = mp2_env%ri_mp2%block_size
      ELSE
         ! Largest block size for which the memory left after mem_base holds the per-block arrays,
         ! including the ngroup/num_integ_group replicas of BIB_C_rec.
         ! Block sizes above MAXVAL(homo) yield no ij blocks and would be rejected one by one by the
         ! loop below, so clamp to [1, MAX(1, MAXVAL(homo))] before converting to INTEGER. This gives
         ! the same result but avoids an overflow in FLOOR and a long countdown for small systems.
         best_block_size = FLOOR(MAX(MIN((mem_real - mem_base)/(mem_per_blk + mem_per_repl_blk*ngroup/num_integ_group), &
                                         REAL(MAX(1, MAXVAL(homo)), KIND=dp)), 1.0_dp))

         ! Decrease the block size until each subgroup gets at least one block
         DO
            IF (SIZE(homo) == 1) THEN
               IF (my_open_shell_ss) THEN
                  ! Same-spin case: diagonal elements do not have to be considered
                  num_IJ_blocks = (homo(1) - 1)/best_block_size
               ELSE
                  num_IJ_blocks = homo(1)/best_block_size
               END IF
               num_IJ_blocks = (num_IJ_blocks*num_IJ_blocks - num_IJ_blocks)/2
            ELSE
               num_IJ_blocks = PRODUCT(homo/best_block_size)
            END IF
            ! Enforce at least one large block for each subgroup
            IF ((num_IJ_blocks >= ngroup .AND. num_IJ_blocks > 0) .OR. best_block_size == 1) EXIT
            best_block_size = best_block_size - 1
         END DO

         IF (SIZE(homo) == 1) THEN
            IF (my_open_shell_ss) THEN
               ! check that best_block_size is not bigger than sqrt(homo-1)
               ! Diagonal elements do not have to be considered; guard against homo(1) == 0
               best_block_size = MIN(FLOOR(SQRT(REAL(MAX(homo(1) - 1, 0), KIND=dp))), best_block_size)
            ELSE
               ! check that best_block_size is not bigger than sqrt(homo)
               best_block_size = MIN(FLOOR(SQRT(REAL(homo(1), KIND=dp))), best_block_size)
            END IF
         END IF
      END IF
      block_size = MAX(1, best_block_size)

      IF (unit_nr > 0) THEN
         WRITE (UNIT=unit_nr, FMT="(T3,A,T75,i6)") &
            "RI_INFO| Block size:", block_size
         CALL m_flush(unit_nr)
      END IF

      ! Determine recv buffer size (BI_C_recv, external_i_aL, external_ab)
      buffer_size = MAX(INT(maxsize(gd_array), KIND=int_8)*block_size, INT(MAX(dimen_RI, MAXVAL(virtual)), KIND=int_8)) &
                    *MAXVAL(maxsize(gd_B_virtual))
      ! The send buffer has the same size as the recv buffer
      IF (calc_forces) buffer_size = buffer_size*2
      ALLOCATE (buffer_1D(buffer_size))

      CALL timestop(handle)

   END SUBROUTINE mp2_ri_get_block_size

! **************************************************************************************************
!> \brief Adds the contribution of one occupied pair (i,j) to the MP2 density-matrix blocks
!>        P_ab and P_ij (diagonal elements only) and to the intermediates Y_i_aP / Y_j_aP
!>        from which Gamma_P_ia is built.
!> \param mp2_env MP2 environment; ri_grad%P_ab, ri_grad%P_ij, scale_S and local_gemm_ctx are used
!> \param para_env_sub subgroup over which the virtual index b is distributed; data are exchanged
!>        between its processes in a ring (send to mepos+shift, receive from mepos-shift)
!> \param gd_B_virtual per-spin distribution of the virtual orbitals over the processes of para_env_sub
!> \param Eigenval orbital energies, indexed (orbital, spin)
!> \param homo number of occupied orbitals per spin
!> \param dimen_RI number of RI functions (leading dimension of my_local_i_aL/my_local_j_aL)
!> \param iiB position of i within its block; the occupied index is my_i + iiB - 1
!> \param jjB position of j within its block; the occupied index is my_j + jjB - 1
!> \param my_B_size per-spin number of locally held virtual orbitals
!> \param my_B_virtual_end per-spin last locally held virtual orbital
!> \param my_B_virtual_start per-spin first locally held virtual orbital
!> \param my_i first occupied index of the current i block
!> \param my_j first occupied index of the current j block
!> \param virtual number of virtual orbitals per spin
!> \param local_ab on entry the (ia|jb) integrals (a: all virtuals of ispin, b: local virtuals of jspin);
!>        overwritten in place with -(ia|jb)/(e_a + e_b - e_i - e_j)
!> \param t_ab same layout as local_ab; contracted with local_ab in the same-spin/closed-shell case
!>        (presumably the pair amplitudes -- TODO confirm at the caller)
!> \param my_local_i_aL local RI quantities of orbital i, (dimen_RI, local virtuals of ispin)
!> \param my_local_j_aL local RI quantities of orbital j, (dimen_RI, local virtuals of jspin)
!> \param open_ss .TRUE. for the open-shell same-spin case (weight 1 instead of 2 in P_ij)
!> \param Y_i_aP accumulates the contribution for i, (local virtuals of ispin, dimen_RI)
!> \param Y_j_aP accumulates the contribution for j, (local virtuals of jspin, dimen_RI)
!> \param local_ba alpha-beta only: integrals with the spin roles of a and b swapped
!>        (a: all virtuals of jspin, b: local virtuals of ispin); divided by the denominator in place
!> \param ispin spin of i
!> \param jspin spin of j; ispin /= jspin selects the alpha-beta code paths
!> \param dgemm_counter counter for GEMM timings/flops
!> \param buffer_1D scratch memory into which the communication arrays external_ab/send_ab are mapped
! **************************************************************************************************
   SUBROUTINE mp2_update_P_gamma(mp2_env, para_env_sub, gd_B_virtual, &
                                 Eigenval, homo, dimen_RI, iiB, jjB, my_B_size, &
                                 my_B_virtual_end, my_B_virtual_start, my_i, my_j, virtual, local_ab, &
                                 t_ab, my_local_i_aL, my_local_j_aL, open_ss, Y_i_aP, Y_j_aP, &
                                 local_ba, ispin, jspin, dgemm_counter, buffer_1D)
      TYPE(mp2_type)                                     :: mp2_env
      TYPE(mp_para_env_type), INTENT(IN)                 :: para_env_sub
      TYPE(group_dist_d1_type), DIMENSION(:), INTENT(IN) :: gd_B_virtual
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: Eigenval
      INTEGER, DIMENSION(:), INTENT(IN)                  :: homo
      INTEGER, INTENT(IN)                                :: dimen_RI, iiB, jjB
      INTEGER, DIMENSION(:), INTENT(IN)                  :: my_B_size, my_B_virtual_end, &
                                                            my_B_virtual_start
      INTEGER, INTENT(IN)                                :: my_i, my_j
      INTEGER, DIMENSION(:), INTENT(IN)                  :: virtual
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :), &
         INTENT(INOUT), TARGET                           :: local_ab
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :), &
         INTENT(IN), TARGET                              :: t_ab, my_local_i_aL, my_local_j_aL
      LOGICAL, INTENT(IN)                                :: open_ss
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :), &
         INTENT(INOUT), TARGET                           :: Y_i_aP, Y_j_aP, local_ba
      INTEGER, INTENT(IN)                                :: ispin, jspin
      TYPE(dgemm_counter_type), INTENT(INOUT)            :: dgemm_counter
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:), TARGET    :: buffer_1D

      CHARACTER(LEN=*), PARAMETER :: routineN = 'mp2_update_P_gamma'

      INTEGER :: a, b, b_global, handle, proc_receive, proc_send, proc_shift, rec_B_size, &
         rec_B_virtual_end, rec_B_virtual_start, send_B_size, send_B_virtual_end, &
         send_B_virtual_start
      INTEGER(KIND=int_8)                                :: offset
      LOGICAL                                            :: alpha_beta
      REAL(KIND=dp)                                      :: factor, P_ij_diag
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :), &
         POINTER                                         :: external_ab, send_ab

      ! ---------------------------------------------------------------------------------------------
      ! Part 1: P_ab and the diagonal of P_ij
      ! ---------------------------------------------------------------------------------------------
      CALL timeset(routineN//"_Pia", handle)

      alpha_beta = .NOT. (ispin == jspin)
      ! Weight of the P_ij diagonal in the same-spin/closed-shell case (not used for alpha-beta)
      IF (open_ss) THEN
         factor = 1.0_dp
      ELSE
         factor = 2.0_dp
      END IF
      ! divide the (ia|jb) integrals by Delta_ij^ab
      ! (b runs over the locally held virtuals of jspin, b_global is its global index)
      DO b = 1, my_B_size(jspin)
         b_global = b + my_B_virtual_start(jspin) - 1
         DO a = 1, virtual(ispin)
            local_ab(a, b) = -local_ab(a, b)/ &
                             (Eigenval(homo(ispin) + a, ispin) + Eigenval(homo(jspin) + b_global, jspin) - &
                              Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, jspin))
         END DO
      END DO
      IF (.NOT. (alpha_beta)) THEN
         P_ij_diag = -SUM(local_ab*t_ab)*factor
      ELSE
         ! update diagonal part of P_ij
         P_ij_diag = -SUM(local_ab*local_ab)*mp2_env%scale_S
         ! More integrals needed only for alpha-beta case: local_ba
         DO b = 1, my_B_size(ispin)
            b_global = b + my_B_virtual_start(ispin) - 1
            DO a = 1, virtual(jspin)
               local_ba(a, b) = -local_ba(a, b)/ &
                                (Eigenval(homo(jspin) + a, jspin) + Eigenval(homo(ispin) + b_global, ispin) - &
                                 Eigenval(my_i + iiB - 1, ispin) - Eigenval(my_j + jjB - 1, jspin))
            END DO
         END DO
      END IF

      ! P_ab and add diagonal part of P_ij

      ! Local contribution: columns of P_ab belonging to the locally held virtuals
      CALL dgemm_counter_start(dgemm_counter)
      IF (.NOT. (alpha_beta)) THEN
         CALL mp2_env%local_gemm_ctx%gemm('T', 'N', my_B_size(ispin), my_B_size(ispin), virtual(ispin), 1.0_dp, &
                                          t_ab, virtual(ispin), local_ab, virtual(ispin), &
                                          1.0_dp, mp2_env%ri_grad%P_ab(ispin)%array(:, &
                                                               my_B_virtual_start(ispin):my_B_virtual_end(ispin)), my_B_size(ispin))
         mp2_env%ri_grad%P_ij(ispin)%array(my_i + iiB - 1, my_i + iiB - 1) = &
            mp2_env%ri_grad%P_ij(ispin)%array(my_i + iiB - 1, my_i + iiB - 1) + P_ij_diag
      ELSE
         CALL mp2_env%local_gemm_ctx%gemm('T', 'N', my_B_size(ispin), my_B_size(ispin), virtual(jspin), mp2_env%scale_S, &
                                          local_ba, virtual(jspin), local_ba, virtual(jspin), 1.0_dp, &
                          mp2_env%ri_grad%P_ab(ispin)%array(:, my_B_virtual_start(ispin):my_B_virtual_end(ispin)), my_B_size(ispin))

         mp2_env%ri_grad%P_ij(ispin)%array(my_i + iiB - 1, my_i + iiB - 1) = &
            mp2_env%ri_grad%P_ij(ispin)%array(my_i + iiB - 1, my_i + iiB - 1) + P_ij_diag

         CALL mp2_env%local_gemm_ctx%gemm('T', 'N', my_B_size(jspin), my_B_size(jspin), virtual(ispin), mp2_env%scale_S, &
                                          local_ab, virtual(ispin), local_ab, virtual(ispin), 1.0_dp, &
                          mp2_env%ri_grad%P_ab(jspin)%array(:, my_B_virtual_start(jspin):my_B_virtual_end(jspin)), my_B_size(jspin))

         mp2_env%ri_grad%P_ij(jspin)%array(my_j + jjB - 1, my_j + jjB - 1) = &
            mp2_env%ri_grad%P_ij(jspin)%array(my_j + jjB - 1, my_j + jjB - 1) + P_ij_diag
      END IF
      ! The summation is over unique pairs. In alpha-beta case, all pairs are unique: subroutine is called for
      ! both i^alpha,j^beta and i^beta,j^alpha. Formally, my_i can be equal to my_j, but they are different
      ! due to spin in alpha-beta case.
      IF ((my_i /= my_j) .AND. (.NOT. alpha_beta)) THEN

         CALL mp2_env%local_gemm_ctx%gemm('N', 'T', my_B_size(ispin), virtual(ispin), my_B_size(ispin), 1.0_dp, &
                                          t_ab(my_B_virtual_start(ispin):my_B_virtual_end(ispin), :), my_B_size(ispin), &
                                          local_ab, virtual(ispin), &
                                          1.0_dp, mp2_env%ri_grad%P_ab(ispin)%array, my_B_size(ispin))

         mp2_env%ri_grad%P_ij(ispin)%array(my_j + jjB - 1, my_j + jjB - 1) = &
            mp2_env%ri_grad%P_ij(ispin)%array(my_j + jjB - 1, my_j + jjB - 1) + P_ij_diag
      END IF
      ! Ring exchange over the subgroup: in step proc_shift, send to mepos+proc_shift and receive from
      ! mepos-proc_shift; the received block is mapped onto the scratch buffer buffer_1D
      DO proc_shift = 1, para_env_sub%num_pe - 1
         proc_send = MODULO(para_env_sub%mepos + proc_shift, para_env_sub%num_pe)
         proc_receive = MODULO(para_env_sub%mepos - proc_shift, para_env_sub%num_pe)

         CALL get_group_dist(gd_B_virtual(jspin), proc_receive, rec_B_virtual_start, rec_B_virtual_end, rec_B_size)
         CALL get_group_dist(gd_B_virtual(jspin), proc_send, send_B_virtual_start, send_B_virtual_end, send_B_size)

         external_ab(1:virtual(ispin), 1:rec_B_size) => buffer_1D(1:INT(virtual(ispin), int_8)*rec_B_size)
         external_ab = 0.0_dp

         CALL para_env_sub%sendrecv(local_ab, proc_send, &
                                    external_ab, proc_receive)

         IF (.NOT. (alpha_beta)) THEN
            CALL mp2_env%local_gemm_ctx%gemm('T', 'N', my_B_size(ispin), rec_B_size, virtual(ispin), 1.0_dp, &
                                             t_ab, virtual(ispin), external_ab, virtual(ispin), &
                              1.0_dp, mp2_env%ri_grad%P_ab(ispin)%array(:, rec_B_virtual_start:rec_B_virtual_end), my_B_size(ispin))
         ELSE
            CALL mp2_env%local_gemm_ctx%gemm('T', 'N', my_B_size(jspin), rec_B_size, virtual(ispin), mp2_env%scale_S, &
                                             local_ab, virtual(ispin), external_ab, virtual(ispin), &
                                             1.0_dp, mp2_env%ri_grad%P_ab(jspin)%array(:, rec_B_virtual_start:rec_B_virtual_end), &
                                             my_B_size(jspin))

            ! For alpha-beta part of alpha-density we need a new parallel code
            ! And new external_ab (of a different size)
            ! Note: this overwrites the send_B_* bounds with those of ispin; they are only used
            ! below in the same-spin branch, which is not taken here.
            CALL get_group_dist(gd_B_virtual(ispin), proc_receive, rec_B_virtual_start, rec_B_virtual_end, rec_B_size)
            CALL get_group_dist(gd_B_virtual(ispin), proc_send, send_B_virtual_start, send_B_virtual_end, send_B_size)
            external_ab(1:virtual(jspin), 1:rec_B_size) => buffer_1D(1:INT(virtual(jspin), int_8)*rec_B_size)
            external_ab = 0.0_dp
            CALL para_env_sub%sendrecv(local_ba, proc_send, &
                                       external_ab, proc_receive)
            CALL mp2_env%local_gemm_ctx%gemm('T', 'N', my_B_size(ispin), rec_B_size, virtual(jspin), mp2_env%scale_S, &
                                             local_ba, virtual(jspin), external_ab, virtual(jspin), &
                              1.0_dp, mp2_env%ri_grad%P_ab(ispin)%array(:, rec_B_virtual_start:rec_B_virtual_end), my_B_size(ispin))
         END IF

         IF ((my_i /= my_j) .AND. (.NOT. alpha_beta)) THEN
            ! Receive buffer at the start of buffer_1D, send buffer placed directly behind it
            external_ab(1:my_B_size(ispin), 1:virtual(ispin)) => &
               buffer_1D(1:INT(virtual(ispin), int_8)*my_B_size(ispin))
            external_ab = 0.0_dp

            offset = INT(virtual(ispin), int_8)*my_B_size(ispin)

            send_ab(1:send_B_size, 1:virtual(ispin)) => buffer_1D(offset + 1:offset + INT(send_B_size, int_8)*virtual(ispin))
            send_ab = 0.0_dp

            CALL mp2_env%local_gemm_ctx%gemm('N', 'T', send_B_size, virtual(ispin), my_B_size(ispin), 1.0_dp, &
                                             t_ab(send_B_virtual_start:send_B_virtual_end, :), send_B_size, &
                                             local_ab, virtual(ispin), 0.0_dp, send_ab, send_B_size)
            CALL para_env_sub%sendrecv(send_ab, proc_send, &
                                       external_ab, proc_receive)

            mp2_env%ri_grad%P_ab(ispin)%array(:, :) = mp2_env%ri_grad%P_ab(ispin)%array + external_ab
         END IF

      END DO
      IF (.NOT. alpha_beta) THEN
         IF (my_i /= my_j) THEN
            CALL dgemm_counter_stop(dgemm_counter, 2*my_B_size(ispin), virtual(ispin), virtual(ispin))
         ELSE
            CALL dgemm_counter_stop(dgemm_counter, my_B_size(ispin), virtual(ispin), virtual(ispin))
         END IF
      ELSE
         CALL dgemm_counter_stop(dgemm_counter, SUM(my_B_size), virtual(ispin), virtual(jspin))
      END IF
      CALL timestop(handle)

      ! ---------------------------------------------------------------------------------------------
      ! Part 2: Y_i_aP / Y_j_aP (from which Gamma_P_ia is made)
      ! ---------------------------------------------------------------------------------------------

      ! Now, Gamma_P_ia (made of Y_ia_P)

      CALL timeset(routineN//"_Gamma", handle)
      CALL dgemm_counter_start(dgemm_counter)
      IF (.NOT. alpha_beta) THEN
         ! Alpha-alpha, beta-beta and closed shell
         CALL mp2_env%local_gemm_ctx%gemm('N', 'T', my_B_size(ispin), dimen_RI, my_B_size(ispin), 1.0_dp, &
                                          t_ab(my_B_virtual_start(ispin):my_B_virtual_end(ispin), :), my_B_size(ispin), &
                                          my_local_j_aL, dimen_RI, 1.0_dp, Y_i_aP, my_B_size(ispin))
      ELSE ! Alpha-beta
         CALL mp2_env%local_gemm_ctx%gemm('N', 'T', my_B_size(ispin), dimen_RI, my_B_size(jspin), mp2_env%scale_S, &
                                          local_ab(my_B_virtual_start(ispin):my_B_virtual_end(ispin), :), my_B_size(ispin), &
                                          my_local_j_aL, dimen_RI, 1.0_dp, Y_i_aP, my_B_size(ispin))
         CALL mp2_env%local_gemm_ctx%gemm('T', 'T', my_B_size(jspin), dimen_RI, my_B_size(ispin), mp2_env%scale_S, &
                                          local_ab(my_B_virtual_start(ispin):my_B_virtual_end(ispin), :), my_B_size(ispin), &
                                          my_local_i_aL, dimen_RI, 1.0_dp, Y_j_aP, my_B_size(jspin))
      END IF

      ! Receive buffer for the Y_i_aP contributions of the other processes; the send buffer follows
      ! at offset (offset is only used inside the loop below, which is empty for a single process)
      IF (para_env_sub%num_pe > 1) THEN
         external_ab(1:my_B_size(ispin), 1:dimen_RI) => buffer_1D(1:INT(my_B_size(ispin), int_8)*dimen_RI)
         external_ab = 0.0_dp

         offset = INT(my_B_size(ispin), int_8)*dimen_RI
      END IF
      !
      DO proc_shift = 1, para_env_sub%num_pe - 1
         proc_send = MODULO(para_env_sub%mepos + proc_shift, para_env_sub%num_pe)
         proc_receive = MODULO(para_env_sub%mepos - proc_shift, para_env_sub%num_pe)

         CALL get_group_dist(gd_B_virtual(ispin), proc_receive, rec_B_virtual_start, rec_B_virtual_end, rec_B_size)
         CALL get_group_dist(gd_B_virtual(ispin), proc_send, send_B_virtual_start, send_B_virtual_end, send_B_size)

         ! Compute the contribution for the virtuals held by proc_send, send it, and add what
         ! proc_receive computed for our own virtuals
         send_ab(1:send_B_size, 1:dimen_RI) => buffer_1D(offset + 1:offset + INT(dimen_RI, int_8)*send_B_size)
         send_ab = 0.0_dp
         IF (.NOT. alpha_beta) THEN
            CALL mp2_env%local_gemm_ctx%gemm('N', 'T', send_B_size, dimen_RI, my_B_size(ispin), 1.0_dp, &
                                             t_ab(send_B_virtual_start:send_B_virtual_end, :), send_B_size, &
                                             my_local_j_aL, dimen_RI, 0.0_dp, send_ab, send_B_size)
            CALL para_env_sub%sendrecv(send_ab, proc_send, external_ab, proc_receive)

            Y_i_aP(:, :) = Y_i_aP + external_ab

         ELSE ! Alpha-beta case
            ! Alpha-alpha part
            CALL mp2_env%local_gemm_ctx%gemm('N', 'T', send_B_size, dimen_RI, my_B_size(jspin), mp2_env%scale_S, &
                                             local_ab(send_B_virtual_start:send_B_virtual_end, :), send_B_size, &
                                             my_local_j_aL, dimen_RI, 0.0_dp, send_ab, send_B_size)
            CALL para_env_sub%sendrecv(send_ab, proc_send, external_ab, proc_receive)
            Y_i_aP(:, :) = Y_i_aP + external_ab
         END IF
      END DO

      IF (alpha_beta) THEN
         ! For beta-beta part (in alpha-beta case) we need a new parallel code
         IF (para_env_sub%num_pe > 1) THEN
            external_ab(1:my_B_size(jspin), 1:dimen_RI) => buffer_1D(1:INT(my_B_size(jspin), int_8)*dimen_RI)
            external_ab = 0.0_dp

            offset = INT(my_B_size(jspin), int_8)*dimen_RI
         END IF
         DO proc_shift = 1, para_env_sub%num_pe - 1
            proc_send = MODULO(para_env_sub%mepos + proc_shift, para_env_sub%num_pe)
            proc_receive = MODULO(para_env_sub%mepos - proc_shift, para_env_sub%num_pe)

            CALL get_group_dist(gd_B_virtual(jspin), proc_send, send_B_virtual_start, send_B_virtual_end, send_B_size)
            send_ab(1:send_B_size, 1:dimen_RI) => buffer_1D(offset + 1:offset + INT(dimen_RI, int_8)*send_B_size)
            send_ab = 0.0_dp
            CALL mp2_env%local_gemm_ctx%gemm('N', 'T', send_B_size, dimen_RI, my_B_size(ispin), mp2_env%scale_S, &
                                             local_ba(send_B_virtual_start:send_B_virtual_end, :), send_B_size, &
                                             my_local_i_aL, dimen_RI, 0.0_dp, send_ab, send_B_size)
            CALL para_env_sub%sendrecv(send_ab, proc_send, external_ab, proc_receive)
            Y_j_aP(:, :) = Y_j_aP + external_ab

         END DO

         ! Here, we just use approximate bounds. For large systems virtual(ispin) is approx virtual(jspin), same for B_size
         CALL dgemm_counter_stop(dgemm_counter, 3*virtual(ispin), dimen_RI, my_B_size(jspin))
      ELSE
         CALL dgemm_counter_stop(dgemm_counter, virtual(ispin), dimen_RI, my_B_size(ispin))
      END IF

      ! Same-spin/closed-shell with i /= j: Y_j_aP also receives a contribution. Here the RI
      ! quantities of i are circulated instead of partial results.
      IF ((my_i /= my_j) .AND. (.NOT. alpha_beta)) THEN
         ! Alpha-alpha, beta-beta and closed shell
         CALL dgemm_counter_start(dgemm_counter)
         CALL mp2_env%local_gemm_ctx%gemm('T', 'T', my_B_size(ispin), dimen_RI, my_B_size(ispin), 1.0_dp, &
                                          t_ab(my_B_virtual_start(ispin):my_B_virtual_end(ispin), :), my_B_size(ispin), &
                                          my_local_i_aL, dimen_RI, 1.0_dp, Y_j_aP, my_B_size(ispin))
         DO proc_shift = 1, para_env_sub%num_pe - 1
            proc_send = MODULO(para_env_sub%mepos + proc_shift, para_env_sub%num_pe)
            proc_receive = MODULO(para_env_sub%mepos - proc_shift, para_env_sub%num_pe)

            CALL get_group_dist(gd_B_virtual(ispin), proc_receive, rec_B_virtual_start, rec_B_virtual_end, rec_B_size)

            external_ab(1:dimen_RI, 1:rec_B_size) => buffer_1D(1:INT(dimen_RI, int_8)*rec_B_size)
            external_ab = 0.0_dp

            CALL para_env_sub%sendrecv(my_local_i_aL, proc_send, &
                                       external_ab, proc_receive)

            ! Alpha-alpha, beta-beta and closed shell
            CALL mp2_env%local_gemm_ctx%gemm('T', 'T', my_B_size(ispin), dimen_RI, rec_B_size, 1.0_dp, &
                                             t_ab(rec_B_virtual_start:rec_B_virtual_end, :), rec_B_size, &
                                             external_ab, dimen_RI, 1.0_dp, Y_j_aP, my_B_size(ispin))
         END DO

         CALL dgemm_counter_stop(dgemm_counter, my_B_size(ispin), dimen_RI, virtual(ispin))
      END IF

      CALL timestop(handle)
   END SUBROUTINE mp2_update_P_gamma

! **************************************************************************************************
!> \brief Adds the block Y_i_aP (occupied orbitals my_i:my_i+my_block_size-1) to Gamma_P_ia and
!>        exchanges with the other ranks of comm_exchange the parts of it that belong to their
!>        RI ranges, accumulating the parts received from them.
!> \param Gamma_P_ia accumulated in place, indexed as Gamma_P_ia(local virtual, occupied, local RI)
!> \param ij_index current step of the pair loop; this rank sends data only if ij_index <= my_ij_pairs
!> \param my_B_size number of local virtual orbitals
!> \param my_block_size number of occupied orbitals in the current block
!> \param my_group_L_size number of RI functions held locally (third extent of a received message)
!> \param my_i first occupied orbital of the block contained in Y_i_aP
!> \param my_ij_pairs number of pair steps that have data on this rank
!> \param ngroup number of groups; pair counters of one color are ngroup apart
!> \param num_integ_group number of RI ranges (irep) stored per rank in ranges_info_array
!> \param integ_group_pos2color_sub color of the subgroup of each rank of comm_exchange
!> \param num_ij_pairs number of pair steps that have data on each rank of comm_exchange
!> \param ij_map pair map; ij_map(spin, counter) is the first occupied orbital of a received block
!> \param ranges_info_array per rank and range irep: (1) start position in the RI index of Y_i_aP,
!>        (3) and (4) first and last local RI index of that rank
!> \param Y_i_aP contribution of this rank, Y_i_aP(local virtual, RI, block)
!> \param comm_exchange communicator over which the RI index of Gamma_P_ia is distributed
!> \param sizes_array number of RI functions held by each rank of comm_exchange
!> \param spin row of ij_map to use
!> \param buffer_1D scratch buffer; holds the outgoing message and, behind it, the incoming one
! **************************************************************************************************
   SUBROUTINE mp2_redistribute_gamma(Gamma_P_ia, ij_index, my_B_size, &
                                     my_block_size, my_group_L_size, my_i, my_ij_pairs, ngroup, &
                                     num_integ_group, integ_group_pos2color_sub, num_ij_pairs, &
                                     ij_map, ranges_info_array, Y_i_aP, comm_exchange, &
                                     sizes_array, spin, buffer_1D)

      REAL(KIND=dp), DIMENSION(:, :, :), INTENT(INOUT)   :: Gamma_P_ia
      INTEGER, INTENT(IN)                                :: ij_index, my_B_size, my_block_size, &
                                                            my_group_L_size, my_i, my_ij_pairs, &
                                                            ngroup, num_integ_group
      INTEGER, ALLOCATABLE, DIMENSION(:), INTENT(IN)     :: integ_group_pos2color_sub, num_ij_pairs
      INTEGER, ALLOCATABLE, DIMENSION(:, :), INTENT(IN)  :: ij_map
      INTEGER, ALLOCATABLE, DIMENSION(:, :, :), &
         INTENT(IN)                                      :: ranges_info_array
      REAL(KIND=dp), DIMENSION(:, :, :), INTENT(IN)      :: Y_i_aP
      TYPE(mp_comm_type), INTENT(IN)                     :: comm_exchange
      INTEGER, ALLOCATABLE, DIMENSION(:), INTENT(IN)     :: sizes_array
      INTEGER, INTENT(IN)                                :: spin
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:), TARGET    :: buffer_1D

      CHARACTER(LEN=*), PARAMETER :: routineN = 'mp2_redistribute_gamma'

      INTEGER :: end_point, handle, handle2, iiB, ij_counter_rec, irep, kkk, lll, Lstart_pos, &
         proc_receive, proc_send, proc_shift, rec_i, rec_ij_index, send_L_size, start_point, tag
      INTEGER(KIND=int_8)                                :: offset
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:, :, :), &
         POINTER                                         :: BI_C_rec, BI_C_send

! In alpha-beta case Y_i_aP_beta is sent as Y_j_aP

      CALL timeset(routineN//"_comm2", handle)

      tag = 43

      IF (ij_index <= my_ij_pairs) THEN
         ! something to send
         ! start with myself: add the part of Y_i_aP that falls into my own RI ranges
         CALL timeset(routineN//"_comm2_w", handle2)
         DO irep = 0, num_integ_group - 1
            Lstart_pos = ranges_info_array(1, irep, comm_exchange%mepos)
            start_point = ranges_info_array(3, irep, comm_exchange%mepos)
            end_point = ranges_info_array(4, irep, comm_exchange%mepos)
!$OMP PARALLEL DO DEFAULT(NONE) &
!$OMP             PRIVATE(kkk,lll,iiB) &
!$OMP             SHARED(start_point,end_point,Lstart_pos,my_block_size,&
!$OMP                    Gamma_P_ia,my_i,my_B_size,Y_i_aP)
            DO kkk = start_point, end_point
               lll = kkk - start_point + Lstart_pos
               DO iiB = 1, my_block_size
                  Gamma_P_ia(1:my_B_size, my_i + iiB - 1, kkk) = &
                     Gamma_P_ia(1:my_B_size, my_i + iiB - 1, kkk) + &
                     Y_i_aP(1:my_B_size, lll, iiB)
               END DO
            END DO
!$OMP END PARALLEL DO
         END DO
         CALL timestop(handle2)

         ! Y_i_aP(my_B_size,dimen_RI,block_size)

         ! Ring exchange: at shift s, send to mepos+s and receive from mepos-s
         DO proc_shift = 1, comm_exchange%num_pe - 1
            proc_send = MODULO(comm_exchange%mepos + proc_shift, comm_exchange%num_pe)
            proc_receive = MODULO(comm_exchange%mepos - proc_shift, comm_exchange%num_pe)

            ! The outgoing message occupies the front of buffer_1D, laid out with the
            ! RI size of the destination rank
            send_L_size = sizes_array(proc_send)
            BI_C_send(1:my_B_size, 1:my_block_size, 1:send_L_size) => &
               buffer_1D(1:INT(my_B_size, int_8)*my_block_size*send_L_size)

            ! The incoming message (if any) is placed directly behind the outgoing one
            offset = INT(my_B_size, int_8)*my_block_size*send_L_size

            ! Pack the part of Y_i_aP that belongs to the RI ranges of proc_send
            CALL timeset(routineN//"_comm2_w", handle2)
            BI_C_send = 0.0_dp
            DO irep = 0, num_integ_group - 1
               Lstart_pos = ranges_info_array(1, irep, proc_send)
               start_point = ranges_info_array(3, irep, proc_send)
               end_point = ranges_info_array(4, irep, proc_send)
!$OMP PARALLEL DO DEFAULT(NONE) &
!$OMP             PRIVATE(kkk,lll,iiB) &
!$OMP             SHARED(start_point,end_point,Lstart_pos,my_block_size,&
!$OMP                    BI_C_send,my_B_size,Y_i_aP)
               DO kkk = start_point, end_point
                  lll = kkk - start_point + Lstart_pos
                  DO iiB = 1, my_block_size
                     BI_C_send(1:my_B_size, iiB, kkk) = Y_i_aP(1:my_B_size, lll, iiB)
                  END DO
               END DO
!$OMP END PARALLEL DO
            END DO
            CALL timestop(handle2)

            rec_ij_index = num_ij_pairs(proc_receive)

            IF (ij_index <= rec_ij_index) THEN
               ! we know that proc_receive has something to send for us, let's see what
               ! (pair counter of proc_receive's color at this step, then its first occupied orbital)
               ij_counter_rec = &
                  (ij_index - MIN(1, integ_group_pos2color_sub(proc_receive)))*ngroup + integ_group_pos2color_sub(proc_receive)

               rec_i = ij_map(spin, ij_counter_rec)

               BI_C_rec(1:my_B_size, 1:my_block_size, 1:my_group_L_size) => &
                  buffer_1D(offset + 1:offset + INT(my_B_size, int_8)*my_block_size*my_group_L_size)
               BI_C_rec = 0.0_dp

               CALL comm_exchange%sendrecv(BI_C_send, proc_send, &
                                           BI_C_rec, proc_receive, tag)

               ! Accumulate the received block into my own RI ranges
               CALL timeset(routineN//"_comm2_w", handle2)
               DO irep = 0, num_integ_group - 1
                  start_point = ranges_info_array(3, irep, comm_exchange%mepos)
                  end_point = ranges_info_array(4, irep, comm_exchange%mepos)
!$OMP PARALLEL WORKSHARE DEFAULT(NONE) &
!$OMP                    SHARED(start_point,end_point,my_block_size,&
!$OMP                           Gamma_P_ia,rec_i,iiB,my_B_size,BI_C_rec)
                  Gamma_P_ia(:, rec_i:rec_i + my_block_size - 1, start_point:end_point) = &
                     Gamma_P_ia(:, rec_i:rec_i + my_block_size - 1, start_point:end_point) + &
                     BI_C_rec(1:my_B_size, :, start_point:end_point)
!$OMP END PARALLEL WORKSHARE
               END DO
               CALL timestop(handle2)

            ELSE
               ! we have something to send but nothing to receive
               CALL comm_exchange%send(BI_C_send, proc_send, tag)

            END IF

         END DO

      ELSE
         ! nothing to send, check if we have to receive
         DO proc_shift = 1, comm_exchange%num_pe - 1
            proc_send = MODULO(comm_exchange%mepos + proc_shift, comm_exchange%num_pe)
            proc_receive = MODULO(comm_exchange%mepos - proc_shift, comm_exchange%num_pe)
            rec_ij_index = num_ij_pairs(proc_receive)

            IF (ij_index <= rec_ij_index) THEN
               ! we know that proc_receive has something to send for us, let's see what
               ij_counter_rec = &
                  (ij_index - MIN(1, integ_group_pos2color_sub(proc_receive)))*ngroup + integ_group_pos2color_sub(proc_receive)

               rec_i = ij_map(spin, ij_counter_rec)

               ! No outgoing message here, so the incoming one starts at the front of buffer_1D
               BI_C_rec(1:my_B_size, 1:my_block_size, 1:my_group_L_size) => &
                  buffer_1D(1:INT(my_B_size, int_8)*my_block_size*my_group_L_size)

               BI_C_rec = 0.0_dp

               CALL comm_exchange%recv(BI_C_rec, proc_receive, tag)

               ! Accumulate the received block into my own RI ranges; the OpenMP region is
               ! disabled for Intel LLVM compilers older than 2025
               CALL timeset(routineN//"_comm2_w", handle2)
               DO irep = 0, num_integ_group - 1
                  start_point = ranges_info_array(3, irep, comm_exchange%mepos)
                  end_point = ranges_info_array(4, irep, comm_exchange%mepos)
#if !defined(__INTEL_LLVM_COMPILER) || (20250000 <= __INTEL_LLVM_COMPILER)
!$OMP PARALLEL WORKSHARE DEFAULT(NONE) &
!$OMP                    SHARED(start_point,end_point,my_block_size,&
!$OMP                           Gamma_P_ia,rec_i,my_B_size,BI_C_rec)
#endif
                  Gamma_P_ia(:, rec_i:rec_i + my_block_size - 1, start_point:end_point) = &
                     Gamma_P_ia(:, rec_i:rec_i + my_block_size - 1, start_point:end_point) + &
                     BI_C_rec(1:my_B_size, :, start_point:end_point)
#if !defined(__INTEL_LLVM_COMPILER) || (20250000 <= __INTEL_LLVM_COMPILER)
!$OMP END PARALLEL WORKSHARE
#endif
               END DO
               CALL timestop(handle2)

            END IF
         END DO

      END IF
      CALL timestop(handle)

   END SUBROUTINE mp2_redistribute_gamma

! **************************************************************************************************
!> \brief Adds the contribution of quasi-degenerate occupied orbital pairs (i,j) to the
!>        occupied-occupied block of the MP2 density, mp2_env%ri_grad%P_ij.
!>        For every triplet (i, j, k-block) handled by this group, the integrals of i, j and the
!>        k block are gathered over comm_exchange (RI index) and over para_env_sub (virtual index),
!>        the amplitudes t_ab of (i,k) are built and contracted with the (j,k) integrals.
!> \param mp2_env MP2 environment; ri_grad%P_ij is updated in place
!> \param Eigenval orbital energies, indexed as Eigenval(orbital, spin)
!> \param homo number of occupied orbitals per spin
!> \param virtual number of virtual orbitals per spin
!> \param open_shell open-shell calculation (amplitude factor scale_T, no factor 2 on P_ij)
!> \param beta_beta accumulate into P_ij(2) instead of P_ij(ispin)
!> \param Bib_C 3-center integrals per spin, BIb_C(spin)%array(local RI, local virtual, occupied)
!> \param unit_nr output unit for progress information (only written if > 0)
!> \param dimen_RI total number of RI functions
!> \param my_B_size number of local virtual orbitals per spin
!> \param ngroup number of groups; the triplets are distributed round-robin over the colors
!> \param my_group_L_size number of RI functions held locally
!> \param color_sub color of this subgroup, selects the triplets handled here
!> \param ranges_info_array RI ranges held by each rank of comm_exchange
!> \param comm_exchange communicator over which the RI index of the integrals is distributed
!> \param para_env_sub communicator over which the virtual index is distributed (see gd_B_virtual)
!> \param para_env global environment, used to agree on the block size of the k orbitals
!> \param my_B_virtual_start first local virtual orbital per spin
!> \param my_B_virtual_end last local virtual orbital per spin
!> \param sizes_array number of RI functions held by each rank of comm_exchange
!> \param gd_B_virtual distribution of the virtual orbitals over para_env_sub, per spin
!> \param integ_group_pos2color_sub color of the subgroup of each rank of comm_exchange
!> \param dgemm_counter timing/flop counter for the local GEMMs
!> \param buffer_1D scratch buffer used for all incoming messages
! **************************************************************************************************
   SUBROUTINE quasi_degenerate_P_ij(mp2_env, Eigenval, homo, virtual, open_shell, &
                                    beta_beta, Bib_C, unit_nr, dimen_RI, &
                                    my_B_size, ngroup, my_group_L_size, &
                                    color_sub, ranges_info_array, comm_exchange, para_env_sub, para_env, &
                                    my_B_virtual_start, my_B_virtual_end, sizes_array, gd_B_virtual, &
                                    integ_group_pos2color_sub, dgemm_counter, buffer_1D)
      TYPE(mp2_type)                                     :: mp2_env
      REAL(KIND=dp), DIMENSION(:, :), INTENT(IN)         :: Eigenval
      INTEGER, DIMENSION(:), INTENT(IN)                  :: homo, virtual
      LOGICAL, INTENT(IN)                                :: open_shell, beta_beta
      TYPE(three_dim_real_array), DIMENSION(:), &
         INTENT(IN)                                      :: BIb_C
      INTEGER, INTENT(IN)                                :: unit_nr, dimen_RI
      INTEGER, DIMENSION(:), INTENT(IN)                  :: my_B_size
      INTEGER, INTENT(IN)                                :: ngroup, my_group_L_size, color_sub
      INTEGER, ALLOCATABLE, DIMENSION(:, :, :), &
         INTENT(IN)                                      :: ranges_info_array
      TYPE(mp_comm_type), INTENT(IN)                     :: comm_exchange
      TYPE(mp_para_env_type), INTENT(IN)                 :: para_env_sub, para_env
      INTEGER, DIMENSION(:), INTENT(IN)                  :: my_B_virtual_start, my_B_virtual_end
      INTEGER, ALLOCATABLE, DIMENSION(:), INTENT(IN)     :: sizes_array
      TYPE(group_dist_d1_type), DIMENSION(:), INTENT(IN) :: gd_B_virtual
      INTEGER, ALLOCATABLE, DIMENSION(:), INTENT(IN)     :: integ_group_pos2color_sub
      TYPE(dgemm_counter_type), INTENT(INOUT)            :: dgemm_counter
      REAL(KIND=dp), CONTIGUOUS, DIMENSION(:), TARGET    :: buffer_1D

      CHARACTER(LEN=*), PARAMETER :: routineN = 'quasi_degenerate_P_ij'

      INTEGER :: a, a_global, b, b_global, block_size, decil, handle, handle2, ijk_counter, &
         ijk_counter_send, ijk_index, ispin, kkB, kspin, max_block_size, max_ijk, my_i, my_ijk, &
         my_j, my_k, my_last_k(2), my_virtual, nspins, proc_receive, proc_send, proc_shift, &
         rec_B_size, rec_B_virtual_end, rec_B_virtual_start, rec_L_size, send_B_size, &
         send_B_virtual_end, send_B_virtual_start, send_i, send_ijk_index, send_j, send_k, &
         size_B_i, size_B_k, tag
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: num_ijk
      INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: ijk_map, send_last_k
      LOGICAL                                            :: alpha_beta, do_recv_i, do_recv_j, &
                                                            do_recv_k, do_send_i, do_send_j, &
                                                            do_send_k
      REAL(KIND=dp)                                      :: amp_fac, P_ij_elem, t_new, t_start
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :), &
         TARGET                                          :: local_ab, local_aL_i, local_aL_j, t_ab
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: local_aL_k
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: BI_C_rec, external_ab, external_aL
      REAL(KIND=dp), DIMENSION(:, :, :), POINTER         :: BI_C_rec_3D

      CALL timeset(routineN//"_ij_sing", handle)

      tag = 44

      nspins = SIZE(BIb_C)
      alpha_beta = (nspins == 2)

      ! Set amplitude factor
      amp_fac = mp2_env%scale_S + mp2_env%scale_T
      IF (open_shell) amp_fac = mp2_env%scale_T

      ! send_last_k(:, proc_shift) is the k block last sent to the rank proc_shift steps ahead in
      ! comm_exchange. It mirrors my_last_k of that rank, which skips receiving an unchanged k block.
      ALLOCATE (send_last_k(2, comm_exchange%num_pe - 1))

      ! Loop(s) over orbital triplets
      DO ispin = 1, nspins
         size_B_i = my_B_size(ispin)
         IF (ispin == 1 .AND. alpha_beta) THEN
            kspin = 2
         ELSE
            kspin = 1
         END IF
         size_B_k = my_B_size(kspin)

         ! Find the number of quasi-degenerate orbitals and orbital triplets

         CALL Find_quasi_degenerate_ij(my_ijk, homo(ispin), homo(kspin), Eigenval(:, ispin), mp2_env, ijk_map, unit_nr, ngroup, &
                                       .NOT. beta_beta .AND. ispin /= 2, comm_exchange, num_ijk, max_ijk, color_sub, &
                                       SIZE(buffer_1D), my_group_L_size, size_B_k, para_env, virtual(ispin), size_B_i)

         my_virtual = virtual(ispin)
         IF (SIZE(ijk_map, 2) > 0) THEN
            max_block_size = ijk_map(4, 1)
         ELSE
            max_block_size = 1
         END IF

         ALLOCATE (local_aL_i(dimen_RI, size_B_i))
         ALLOCATE (local_aL_j(dimen_RI, size_B_i))
         ALLOCATE (local_aL_k(dimen_RI, size_B_k, max_block_size))
         ALLOCATE (t_ab(my_virtual, size_B_k))

         my_last_k = -1
         send_last_k = -1

         t_start = m_walltime()
         DO ijk_index = 1, max_ijk

            ! Prediction is unreliable if we are in the first step of the loop
            IF (unit_nr > 0 .AND. ijk_index > 1) THEN
               decil = ijk_index*10/max_ijk
               IF (decil /= (ijk_index - 1)*10/max_ijk) THEN
                  t_new = m_walltime()
                  t_new = (t_new - t_start)/60.0_dp*(max_ijk - ijk_index + 1)/(ijk_index - 1)
                  WRITE (unit_nr, FMT="(T3,A)") "Percentage of finished loop: "// &
                     cp_to_string(decil*10)//". Minutes left: "//cp_to_string(t_new)
                  CALL m_flush(unit_nr)
               END IF
            END IF

            IF (ijk_index <= my_ijk) THEN
               ! work to be done
               ijk_counter = (ijk_index - MIN(1, color_sub))*ngroup + color_sub
               my_i = ijk_map(1, ijk_counter)
               my_j = ijk_map(2, ijk_counter)
               my_k = ijk_map(3, ijk_counter)
               block_size = ijk_map(4, ijk_counter)

               ! i and j need not be received if they are part of the k block (same spin);
               ! the k block need not be received if it is the same as in the previous step
               do_recv_i = (ispin /= kspin) .OR. my_i < my_k .OR. my_i > my_k + block_size - 1
               do_recv_j = (ispin /= kspin) .OR. my_j < my_k .OR. my_j > my_k + block_size - 1
               do_recv_k = my_k /= my_last_k(1) .OR. my_k + block_size - 1 /= my_last_k(2)
               my_last_k(1) = my_k
               my_last_k(2) = my_k + block_size - 1

               local_aL_i = 0.0_dp
               IF (do_recv_i) THEN
                  CALL fill_local_i_aL_2D(local_al_i, ranges_info_array(:, :, comm_exchange%mepos), &
                                          BIb_C(ispin)%array(:, :, my_i))
               END IF

               local_aL_j = 0.0_dp
               IF (do_recv_j) THEN
                  CALL fill_local_i_aL_2D(local_al_j, ranges_info_array(:, :, comm_exchange%mepos), &
                                          BIb_C(ispin)%array(:, :, my_j))
               END IF

               IF (do_recv_k) THEN
                  local_aL_k = 0.0_dp
                  CALL fill_local_i_aL(local_aL_k(:, :, 1:block_size), ranges_info_array(:, :, comm_exchange%mepos), &
                                       BIb_C(kspin)%array(:, :, my_k:my_k + block_size - 1))
               END IF

               ! Gather the RI index: ring exchange over comm_exchange
               CALL timeset(routineN//"_comm", handle2)
               DO proc_shift = 1, comm_exchange%num_pe - 1
                  proc_send = MODULO(comm_exchange%mepos + proc_shift, comm_exchange%num_pe)
                  proc_receive = MODULO(comm_exchange%mepos - proc_shift, comm_exchange%num_pe)

                  send_ijk_index = num_ijk(proc_send)

                  rec_L_size = sizes_array(proc_receive)
                  BI_C_rec(1:rec_L_size, 1:size_B_i) => buffer_1D(1:INT(rec_L_size, KIND=int_8)*size_B_i)

                  do_send_i = .FALSE.
                  do_send_j = .FALSE.
                  do_send_k = .FALSE.
                  IF (ijk_index <= send_ijk_index) THEN
                     ! something to send
                     ijk_counter_send = (ijk_index - MIN(1, integ_group_pos2color_sub(proc_send)))* &
                                        ngroup + integ_group_pos2color_sub(proc_send)
                     send_i = ijk_map(1, ijk_counter_send)
                     send_j = ijk_map(2, ijk_counter_send)
                     send_k = ijk_map(3, ijk_counter_send)

                     ! NOTE: the local block_size is used for the triplet of proc_send; the message
                     ! sizes below rely on all triplets of one ijk_index step sharing a block size
                     do_send_i = (ispin /= kspin) .OR. send_i < send_k .OR. send_i > send_k + block_size - 1
                     do_send_j = (ispin /= kspin) .OR. send_j < send_k .OR. send_j > send_k + block_size - 1
                     do_send_k = send_k /= send_last_k(1, proc_shift) .OR. send_k + block_size - 1 /= send_last_k(2, proc_shift)
                     send_last_k(1, proc_shift) = send_k
                     send_last_k(2, proc_shift) = send_k + block_size - 1
                  END IF

                  ! occupied i
                  BI_C_rec = 0.0_dp
                  IF (do_send_i) THEN
                     IF (do_recv_i) THEN
                        CALL comm_exchange%sendrecv(BIb_C(ispin)%array(:, :, send_i), proc_send, &
                                                    BI_C_rec, proc_receive, tag)
                     ELSE
                        CALL comm_exchange%send(BIb_C(ispin)%array(:, :, send_i), proc_send, tag)
                     END IF
                  ELSE IF (do_recv_i) THEN
                     CALL comm_exchange%recv(BI_C_rec, proc_receive, tag)
                  END IF
                  IF (do_recv_i) THEN
                     CALL fill_local_i_aL_2D(local_al_i, ranges_info_array(:, :, proc_receive), BI_C_rec)
                  END IF

                  ! occupied j
                  BI_C_rec = 0.0_dp
                  IF (do_send_j) THEN
                     IF (do_recv_j) THEN
                        CALL comm_exchange%sendrecv(BIb_C(ispin)%array(:, :, send_j), proc_send, &
                                                    BI_C_rec, proc_receive, tag)
                     ELSE
                        CALL comm_exchange%send(BIb_C(ispin)%array(:, :, send_j), proc_send, tag)
                     END IF
                  ELSE IF (do_recv_j) THEN
                     CALL comm_exchange%recv(BI_C_rec, proc_receive, tag)
                  END IF
                  IF (do_recv_j) THEN
                     CALL fill_local_i_aL_2D(local_al_j, ranges_info_array(:, :, proc_receive), BI_C_rec)
                  END IF

                  ! occupied k
                  BI_C_rec_3D(1:rec_L_size, 1:size_B_k, 1:block_size) => &
                     buffer_1D(1:INT(rec_L_size, KIND=int_8)*size_B_k*block_size)
                  IF (do_send_k) THEN
                     IF (do_recv_k) THEN
                        CALL comm_exchange%sendrecv(BIb_C(kspin)%array(:, :, send_k:send_k + block_size - 1), proc_send, &
                                                    BI_C_rec_3D, proc_receive, tag)
                     ELSE
                        ! Same payload and destination as in the sendrecv above
                        CALL comm_exchange%send(BIb_C(kspin)%array(:, :, send_k:send_k + block_size - 1), proc_send, tag)
                     END IF
                  ELSE IF (do_recv_k) THEN
                     CALL comm_exchange%recv(BI_C_rec_3D, proc_receive, tag)
                  END IF
                  IF (do_recv_k) THEN
                     CALL fill_local_i_aL(local_al_k(:, :, 1:block_size), ranges_info_array(:, :, proc_receive), BI_C_rec_3D)
                  END IF
               END DO

               ! i and j that lie inside the k block are copied from it
               IF (.NOT. do_recv_i) local_aL_i(:, :) = local_aL_k(:, :, my_i - my_k + 1)
               IF (.NOT. do_recv_j) local_aL_j(:, :) = local_aL_k(:, :, my_j - my_k + 1)
               CALL timestop(handle2)

               ! expand integrals
               DO kkB = 1, block_size
                  CALL timeset(routineN//"_exp_ik", handle2)
                  CALL dgemm_counter_start(dgemm_counter)
                  ALLOCATE (local_ab(my_virtual, size_B_k))
                  local_ab = 0.0_dp
                  CALL mp2_env%local_gemm_ctx%gemm('T', 'N', size_B_i, size_B_k, dimen_RI, 1.0_dp, &
                                                   local_aL_i, dimen_RI, local_aL_k(:, :, kkB), dimen_RI, &
                                          0.0_dp, local_ab(my_B_virtual_start(ispin):my_B_virtual_end(ispin), 1:size_B_k), size_B_i)
                  DO proc_shift = 1, para_env_sub%num_pe - 1
                     proc_send = MODULO(para_env_sub%mepos + proc_shift, para_env_sub%num_pe)
                     proc_receive = MODULO(para_env_sub%mepos - proc_shift, para_env_sub%num_pe)

                     CALL get_group_dist(gd_B_virtual(ispin), proc_receive, rec_B_virtual_start, rec_B_virtual_end, rec_B_size)

                     external_aL(1:dimen_RI, 1:rec_B_size) => buffer_1D(1:INT(dimen_RI, KIND=int_8)*rec_B_size)

                     ! The virtual index is distributed over para_env_sub (gd_B_virtual), and the
                     ! ranks above are para_env_sub ranks, so the exchange runs on para_env_sub
                     CALL para_env_sub%sendrecv(local_aL_i, proc_send, &
                                                external_aL, proc_receive, tag)

                     CALL mp2_env%local_gemm_ctx%gemm('T', 'N', rec_B_size, size_B_k, dimen_RI, 1.0_dp, &
                                                      external_aL, dimen_RI, local_aL_k(:, :, kkB), dimen_RI, &
                                                    0.0_dp, local_ab(rec_B_virtual_start:rec_B_virtual_end, 1:size_B_k), rec_B_size)
                  END DO
                  CALL dgemm_counter_stop(dgemm_counter, my_virtual, size_B_k, dimen_RI)
                  CALL timestop(handle2)

                  ! Amplitudes
                  CALL timeset(routineN//"_tab", handle2)
                  t_ab = 0.0_dp
                  ! Alpha-alpha, beta-beta and closed shell
                  IF (.NOT. alpha_beta) THEN
                     DO b = 1, size_B_k
                        b_global = b + my_B_virtual_start(1) - 1
                        DO a = 1, my_B_size(1)
                           a_global = a + my_B_virtual_start(1) - 1
                           t_ab(a_global, b) = (amp_fac*local_ab(a_global, b) - mp2_env%scale_T*local_ab(b_global, a))/ &
                                               (Eigenval(my_i, 1) + Eigenval(my_k + kkB - 1, 1) &
                                                - Eigenval(homo(1) + a_global, 1) - Eigenval(homo(1) + b_global, 1))
                        END DO
                     END DO
                  ELSE
                     DO b = 1, size_B_k
                        b_global = b + my_B_virtual_start(kspin) - 1
                        DO a = 1, my_B_size(ispin)
                           a_global = a + my_B_virtual_start(ispin) - 1
                           t_ab(a_global, b) = mp2_env%scale_S*local_ab(a_global, b)/ &
                                               (Eigenval(my_i, ispin) + Eigenval(my_k + kkB - 1, kspin) &
                                                - Eigenval(homo(ispin) + a_global, ispin) - Eigenval(homo(kspin) + b_global, kspin))
                        END DO
                     END DO
                  END IF

                  ! Same-spin case: the exchange term needs the transposed integrals held by the other ranks
                  IF (.NOT. alpha_beta) THEN
                     DO proc_shift = 1, para_env_sub%num_pe - 1
                        proc_send = MODULO(para_env_sub%mepos + proc_shift, para_env_sub%num_pe)
                        proc_receive = MODULO(para_env_sub%mepos - proc_shift, para_env_sub%num_pe)
                        CALL get_group_dist(gd_B_virtual(1), proc_receive, rec_B_virtual_start, rec_B_virtual_end, rec_B_size)
                        CALL get_group_dist(gd_B_virtual(1), proc_send, send_B_virtual_start, send_B_virtual_end, send_B_size)

                        external_ab(1:size_B_i, 1:rec_B_size) => buffer_1D(1:INT(size_B_i, KIND=int_8)*rec_B_size)
                        CALL para_env_sub%sendrecv(local_ab(send_B_virtual_start:send_B_virtual_end, 1:size_B_k), proc_send, &
                                                   external_ab(1:size_B_i, 1:rec_B_size), proc_receive, tag)

                        DO b = 1, my_B_size(1)
                           b_global = b + my_B_virtual_start(1) - 1
                           DO a = 1, rec_B_size
                              a_global = a + rec_B_virtual_start - 1
                              t_ab(a_global, b) = (amp_fac*local_ab(a_global, b) - mp2_env%scale_T*external_ab(b, a))/ &
                                                  (Eigenval(my_i, 1) + Eigenval(my_k + kkB - 1, 1) &
                                                   - Eigenval(homo(1) + a_global, 1) - Eigenval(homo(1) + b_global, 1))
                           END DO
                        END DO
                     END DO
                  END IF
                  CALL timestop(handle2)

                  ! Expand the second set of integrals
                  CALL timeset(routineN//"_exp_jk", handle2)
                  local_ab = 0.0_dp
                  CALL dgemm_counter_start(dgemm_counter)
                  CALL mp2_env%local_gemm_ctx%gemm('T', 'N', size_B_i, size_B_k, dimen_RI, 1.0_dp, &
                                                   local_aL_j, dimen_RI, local_aL_k(:, :, kkB), dimen_RI, &
                                          0.0_dp, local_ab(my_B_virtual_start(ispin):my_B_virtual_end(ispin), 1:size_B_k), size_B_i)
                  DO proc_shift = 1, para_env_sub%num_pe - 1
                     proc_send = MODULO(para_env_sub%mepos + proc_shift, para_env_sub%num_pe)
                     proc_receive = MODULO(para_env_sub%mepos - proc_shift, para_env_sub%num_pe)

                     CALL get_group_dist(gd_B_virtual(ispin), proc_receive, rec_B_virtual_start, rec_B_virtual_end, rec_B_size)

                     external_aL(1:dimen_RI, 1:rec_B_size) => buffer_1D(1:INT(dimen_RI, KIND=int_8)*rec_B_size)

                     ! Virtual index exchange: para_env_sub, as for local_aL_i above
                     CALL para_env_sub%sendrecv(local_aL_j, proc_send, &
                                                external_aL, proc_receive, tag)
                     CALL mp2_env%local_gemm_ctx%gemm('T', 'N', rec_B_size, size_B_k, dimen_RI, 1.0_dp, &
                                                      external_aL, dimen_RI, local_aL_k(:, :, kkB), dimen_RI, &
                                                    0.0_dp, local_ab(rec_B_virtual_start:rec_B_virtual_end, 1:size_B_k), rec_B_size)
                  END DO
                  CALL dgemm_counter_stop(dgemm_counter, my_virtual, size_B_k, dimen_RI)
                  CALL timestop(handle2)

                  CALL timeset(routineN//"_Pij", handle2)
                  DO b = 1, size_B_k
                     b_global = b + my_B_virtual_start(kspin) - 1
                     DO a = 1, my_B_size(ispin)
                        a_global = a + my_B_virtual_start(ispin) - 1
                        local_ab(a_global, b) = &
                           local_ab(a_global, b)/(Eigenval(my_j, ispin) + Eigenval(my_k + kkB - 1, kspin) &
                                                - Eigenval(homo(ispin) + a_global, ispin) - Eigenval(homo(kspin) + b_global, kspin))
                     END DO
                  END DO
                  !
                  P_ij_elem = SUM(local_ab*t_ab)
                  DEALLOCATE (local_ab)
                  IF ((.NOT. open_shell) .AND. (.NOT. alpha_beta)) THEN
                     P_ij_elem = P_ij_elem*2.0_dp
                  END IF
                  IF (beta_beta) THEN
                     mp2_env%ri_grad%P_ij(2)%array(my_i, my_j) = mp2_env%ri_grad%P_ij(2)%array(my_i, my_j) - P_ij_elem
                     mp2_env%ri_grad%P_ij(2)%array(my_j, my_i) = mp2_env%ri_grad%P_ij(2)%array(my_j, my_i) - P_ij_elem
                  ELSE
                     mp2_env%ri_grad%P_ij(ispin)%array(my_i, my_j) = mp2_env%ri_grad%P_ij(ispin)%array(my_i, my_j) - P_ij_elem
                     mp2_env%ri_grad%P_ij(ispin)%array(my_j, my_i) = mp2_env%ri_grad%P_ij(ispin)%array(my_j, my_i) - P_ij_elem
                  END IF
                  CALL timestop(handle2)
               END DO
            ELSE
               CALL timeset(routineN//"_comm", handle2)
               ! no work to be done, possible messages to be exchanged
               DO proc_shift = 1, comm_exchange%num_pe - 1
                  proc_send = MODULO(comm_exchange%mepos + proc_shift, comm_exchange%num_pe)
                  proc_receive = MODULO(comm_exchange%mepos - proc_shift, comm_exchange%num_pe)

                  send_ijk_index = num_ijk(proc_send)

                  IF (ijk_index <= send_ijk_index) THEN
                     ! something to send
                     ijk_counter_send = (ijk_index - MIN(1, integ_group_pos2color_sub(proc_send)))*ngroup + &
                                        integ_group_pos2color_sub(proc_send)
                     send_i = ijk_map(1, ijk_counter_send)
                     send_j = ijk_map(2, ijk_counter_send)
                     send_k = ijk_map(3, ijk_counter_send)
                     block_size = ijk_map(4, ijk_counter_send)

                     do_send_i = (ispin /= kspin) .OR. send_i < send_k .OR. send_i > send_k + block_size - 1
                     do_send_j = (ispin /= kspin) .OR. send_j < send_k .OR. send_j > send_k + block_size - 1
                     ! The receiver skips an unchanged k block (my_last_k), so track it as in the work branch
                     do_send_k = send_k /= send_last_k(1, proc_shift) .OR. send_k + block_size - 1 /= send_last_k(2, proc_shift)
                     send_last_k(1, proc_shift) = send_k
                     send_last_k(2, proc_shift) = send_k + block_size - 1
                     ! occupied i
                     IF (do_send_i) THEN
                        CALL comm_exchange%send(BIb_C(ispin)%array(:, :, send_i), proc_send, tag)
                     END IF
                     ! occupied j
                     IF (do_send_j) THEN
                        CALL comm_exchange%send(BIb_C(ispin)%array(:, :, send_j), proc_send, tag)
                     END IF
                     ! occupied k
                     IF (do_send_k) THEN
                        CALL comm_exchange%send(BIb_C(kspin)%array(:, :, send_k:send_k + block_size - 1), proc_send, tag)
                     END IF
                  END IF

               END DO ! proc loop
               CALL timestop(handle2)
            END IF
         END DO ! ijk_index loop
         DEALLOCATE (local_aL_i)
         DEALLOCATE (local_aL_j)
         DEALLOCATE (local_aL_k)
         DEALLOCATE (t_ab)
         DEALLOCATE (ijk_map)
      END DO ! over number of loops (ispin)
      DEALLOCATE (send_last_k)
      CALL timestop(handle)

   END SUBROUTINE Quasi_degenerate_P_ij

! **************************************************************************************************
!> \brief ...
!> \param my_ijk ...
!> \param homo ...
!> \param homo_beta ...
!> \param Eigenval ...
!> \param mp2_env ...
!> \param ijk_map ...
!> \param unit_nr ...
!> \param ngroup ...
!> \param do_print_alpha ...
!> \param comm_exchange ...
!> \param num_ijk ...
!> \param max_ijk ...
!> \param color_sub ...
!> \param buffer_size ...
!> \param my_group_L_size ...
!> \param B_size_k ...
!> \param para_env ...
!> \param virtual ...
!> \param B_size_i ...
! **************************************************************************************************
   SUBROUTINE Find_quasi_degenerate_ij(my_ijk, homo, homo_beta, Eigenval, mp2_env, ijk_map, unit_nr, ngroup, &
                                       do_print_alpha, comm_exchange, num_ijk, max_ijk, color_sub, &
                                       buffer_size, my_group_L_size, B_size_k, para_env, virtual, B_size_i)

      ! Builds the task list ijk_map for all pairs i < j whose orbital energies differ by less than
      ! EPS_CANONICAL. For each such pair, k = 1..homo_beta is covered by blocks of block_size consecutive
      ! k indices (only as many blocks as can be spread evenly over the ngroup subgroups); the remaining
      ! k indices become single-k tasks. Task n is handled by the subgroup with MOD(n, ngroup) == color_sub.

      INTEGER, INTENT(OUT)                               :: my_ijk
      INTEGER, INTENT(IN)                                :: homo, homo_beta
      REAL(KIND=dp), DIMENSION(:), INTENT(IN)            :: Eigenval
      TYPE(mp2_type), INTENT(IN)                         :: mp2_env
      INTEGER, ALLOCATABLE, DIMENSION(:, :), INTENT(OUT) :: ijk_map
      INTEGER, INTENT(IN)                                :: unit_nr, ngroup
      LOGICAL, INTENT(IN)                                :: do_print_alpha
      TYPE(mp_comm_type), INTENT(IN)                     :: comm_exchange
      INTEGER, ALLOCATABLE, DIMENSION(:), INTENT(OUT)    :: num_ijk
      INTEGER, INTENT(OUT)                               :: max_ijk
      INTEGER, INTENT(IN)                                :: color_sub, buffer_size, my_group_L_size, &
                                                            B_size_k
      TYPE(mp_para_env_type), INTENT(IN)                 :: para_env
      INTEGER, INTENT(IN)                                :: virtual, B_size_i

      INTEGER :: block_size, communication_steps, communication_volume, iib, ij_counter, &
         ijk_counter, jjb, kkb, max_block_size, max_num_k_blocks, min_communication_volume, &
         my_steps, num_k_blocks, num_sing_ij, total_ijk
      INTEGER(KIND=int_8)                                :: k_block_elems, mem
      LOGICAL, ALLOCATABLE, DIMENSION(:, :)              :: ijk_marker

      ALLOCATE (num_ijk(0:comm_exchange%num_pe - 1))

      ! Count the quasi-degenerate pairs i < j
      num_sing_ij = 0
      DO iiB = 1, homo
         ! diagonal elements already updated
         DO jjB = iiB + 1, homo
            IF (ABS(Eigenval(jjB) - Eigenval(iiB)) < mp2_env%ri_grad%eps_canonical) &
               num_sing_ij = num_sing_ij + 1
         END DO
      END DO

      IF (unit_nr > 0) THEN
         IF (do_print_alpha) THEN
            WRITE (UNIT=unit_nr, FMT="(T3,A,T75,i6)") &
               "MO_INFO| Number of ij pairs below EPS_CANONICAL:", num_sing_ij
         ELSE
            WRITE (UNIT=unit_nr, FMT="(T3,A,T75,i6)") &
               "MO_INFO| Number of ij pairs (spin beta) below EPS_CANONICAL:", num_sing_ij
         END IF
      END IF

      ! Determine the block size. A block needs my_group_L_size*B_size_k elements per k index;
      ! this product is formed in 64 bit so that neither it nor the memory estimate overflows.
      k_block_elems = INT(my_group_L_size, KIND=int_8)*INT(B_size_k, KIND=int_8)

      ! Blocks larger than homo_beta can never be filled, so homo_beta is the natural upper bound.
      ! A rank without local data (k_block_elems == 0) imposes no further limit and must not divide by it.
      max_block_size = homo_beta

      ! First limit: available buffer
      IF (k_block_elems > 0_int_8) &
         max_block_size = INT(MIN(INT(max_block_size, KIND=int_8), INT(buffer_size, KIND=int_8)/k_block_elems))

      ! Second limit: memory
      CALL m_memory(mem)
      ! Convert to number of doubles
      mem = mem/8
      ! Remove local_ab (2x) and local_aL_i (2x)
      mem = mem - 2_int_8*(INT(virtual, KIND=int_8)*B_size_k + INT(B_size_i, KIND=int_8)*my_group_L_size)
      ! Capping with homo_beta before INT keeps the conversion to default integer in range
      IF (k_block_elems > 0_int_8) &
         max_block_size = MIN(max_block_size, INT(MAX(1_int_8, MIN(mem/k_block_elems, INT(homo_beta, KIND=int_8)))))

      ! Exchange the limit so that every rank searches the same range of block sizes
      CALL para_env%min(max_block_size)

      ! Find now the block size which minimizes the communication volume and then the number of communication steps.
      ! Cost model: a block of size b costs 2 + b in volume and 1 step; a single-k task costs 3 and 1 step.
      ! Reference is block_size = 1, i.e. homo_beta*num_sing_ij single-k tasks.
      block_size = 1
      min_communication_volume = 3*homo_beta*num_sing_ij
      communication_steps = homo_beta*num_sing_ij
      DO iiB = max_block_size, 2, -1
         ! Full blocks of size iiB, trimmed to a multiple of ngroup so they spread evenly over the subgroups
         max_num_k_blocks = (homo_beta/iiB)*num_sing_ij
         num_k_blocks = max_num_k_blocks - MOD(max_num_k_blocks, ngroup)
         communication_volume = num_k_blocks*(2 + iiB) + 3*(homo_beta*num_sing_ij - iiB*num_k_blocks)
         my_steps = num_k_blocks + homo_beta*num_sing_ij - iiB*num_k_blocks
         IF (communication_volume < min_communication_volume) THEN
            block_size = iiB
            min_communication_volume = communication_volume
            communication_steps = my_steps
         ELSE IF (communication_volume == min_communication_volume .AND. my_steps < communication_steps) THEN
            block_size = iiB
            communication_steps = my_steps
         END IF
      END DO

      IF (unit_nr > 0) THEN
         WRITE (UNIT=unit_nr, FMT="(T3,A,T75,i6)") &
            "MO_INFO| Block size:", block_size
         CALL m_flush(unit_nr)
      END IF

      ! Calculate number of large blocks
      max_num_k_blocks = (homo_beta/block_size)*num_sing_ij
      num_k_blocks = max_num_k_blocks - MOD(max_num_k_blocks, ngroup)

      ! Large blocks plus the single-k tasks for every k not covered by a block
      total_ijk = num_k_blocks + homo_beta*num_sing_ij - num_k_blocks*block_size
      ALLOCATE (ijk_map(4, total_ijk))
      ijk_map = 0
      ! ijk_marker(k, ij) stays .TRUE. while k has not been assigned to a block for pair ij
      ALLOCATE (ijk_marker(homo_beta, num_sing_ij))
      ijk_marker = .TRUE.

      my_ijk = 0
      ijk_counter = 0

      ! First pass: hand out the num_k_blocks large blocks
      ij_counter = 0
      DO iiB = 1, homo
         ! diagonal elements already updated
         DO jjB = iiB + 1, homo
            IF (ABS(Eigenval(jjB) - Eigenval(iiB)) >= mp2_env%ri_grad%eps_canonical) CYCLE
            ij_counter = ij_counter + 1
            DO kkB = 1, homo_beta - MOD(homo_beta, block_size), block_size
               IF (ijk_counter + 1 > num_k_blocks) EXIT
               ijk_counter = ijk_counter + 1
               ijk_marker(kkB:kkB + block_size - 1, ij_counter) = .FALSE.
               ijk_map(1, ijk_counter) = iiB
               ijk_map(2, ijk_counter) = jjB
               ijk_map(3, ijk_counter) = kkB
               ijk_map(4, ijk_counter) = block_size
               IF (MOD(ijk_counter, ngroup) == color_sub) my_ijk = my_ijk + 1
            END DO
         END DO
      END DO

      ! Second pass: every k not covered by a block becomes a single-k task (ijk_counter continues)
      ij_counter = 0
      DO iiB = 1, homo
         ! diagonal elements already updated
         DO jjB = iiB + 1, homo
            IF (ABS(Eigenval(jjB) - Eigenval(iiB)) >= mp2_env%ri_grad%eps_canonical) CYCLE
            ij_counter = ij_counter + 1
            DO kkB = 1, homo_beta
               IF (ijk_marker(kkB, ij_counter)) THEN
                  ijk_counter = ijk_counter + 1
                  ijk_map(1, ijk_counter) = iiB
                  ijk_map(2, ijk_counter) = jjB
                  ijk_map(3, ijk_counter) = kkB
                  ijk_map(4, ijk_counter) = 1
                  IF (MOD(ijk_counter, ngroup) == color_sub) my_ijk = my_ijk + 1
               END IF
            END DO
         END DO
      END DO

      DEALLOCATE (ijk_marker)

      ! Collect the per-rank task counts and their maximum
      CALL comm_exchange%allgather(my_ijk, num_ijk)
      max_ijk = MAXVAL(num_ijk)

   END SUBROUTINE Find_quasi_degenerate_ij

END MODULE mp2_ri_gpw
